handwriting.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. /*
  2. * MIT License
  3. *
  4. * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
  5. *
  6. * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy of this
  9. * software and associated documentation files (the "Software"), to deal in the Software
  10. * without restriction, including without limitation the rights to use, copy, modify,
  11. * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
  12. * to permit persons to whom the Software is furnished to do so, subject to the following
  13. * conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all copies
  16. * or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
  19. * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
  20. * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
  21. * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
  22. * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  23. * DEALINGS IN THE SOFTWARE.
  24. */
  25. package handwriting
  26. import (
  27. "bytes"
  28. context "context"
  29. fmt "fmt"
  30. "net"
  31. neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
  32. gradients "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork/gradients"
  33. training "git.semlanik.org/semlanik/NeuralNetwork/training"
  34. "gonum.org/v1/gonum/mat"
  35. grpc "google.golang.org/grpc"
  36. )
  37. type HandwritingService struct {
  38. nn *neuralnetwork.NeuralNetwork
  39. }
  40. func NewHandwritingService() (hws *HandwritingService) {
  41. hws = &HandwritingService{}
  42. hws.nn, _ = neuralnetwork.NewNeuralNetwork([]int{784, 300, 10}, gradients.NewRPropInitializer(gradients.RPropConfig{
  43. NuPlus: 1.2,
  44. NuMinus: 0.5,
  45. DeltaMax: 50.0,
  46. DeltaMin: 0.000001,
  47. }))
  48. return
  49. }
  50. func (hws *HandwritingService) Recognize(ctx context.Context, matrix *Matrix) (*Result, error) {
  51. fmt.Printf("Recognize %v size: %v\n", len(matrix.Data), hws.nn.Sizes[0])
  52. dense := mat.NewDense(hws.nn.Sizes[0], 1, matrix.Data)
  53. index, _ := hws.nn.Predict(dense)
  54. fmt.Printf("Recognition result %v\n", index)
  55. return &Result{ResultCharacter: uint32(index)}, nil
  56. }
  57. func (hws *HandwritingService) SetNeuralNetworkData(ctx context.Context, nnRaw *NeuralNetworkRaw) (*None, error) {
  58. fmt.Println("SetNeuralNetworkData")
  59. r := bytes.NewReader(nnRaw.Data)
  60. hws.nn.LoadState(r)
  61. return &None{}, nil
  62. }
  63. func (hws *HandwritingService) GetNeuralNetworkData(context.Context, *None) (*NeuralNetworkRaw, error) {
  64. nnRaw := &NeuralNetworkRaw{}
  65. fmt.Println("SetNeuralNetworkData")
  66. r := bytes.NewReader(nnRaw.Data)
  67. hws.nn.LoadState(r)
  68. return nnRaw, nil
  69. }
  70. func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) {
  71. fmt.Println("ReTrain")
  72. trainer := training.NewMNISTReader("./train-images-idx3-ubyte", "./train-labels-idx1-ubyte", "./t10k-images-idx3-ubyte", "./t10k-labels-idx1-ubyte")
  73. failCount, total := hws.nn.Validate(trainer)
  74. fmt.Printf("Fail count before: %v/%v\n\n", failCount, total)
  75. hws.nn.Train(trainer, 100)
  76. hws.nn.SaveStateToFile("./mnistnet.nnd")
  77. failCount, total = hws.nn.Validate(trainer)
  78. fmt.Printf("Fail count after: %v/%v\n\n", failCount, total)
  79. fmt.Println("ReTrain finished")
  80. return &None{}, nil
  81. }
  82. func (hws *HandwritingService) Run() {
  83. grpcServer := grpc.NewServer()
  84. RegisterHandwritingServer(grpcServer, hws)
  85. lis, err := net.Listen("tcp", "localhost:65001")
  86. if err != nil {
  87. fmt.Printf("Failed to listen: %v\n", err)
  88. }
  89. fmt.Printf("Listen localhost:65001\n")
  90. if err := grpcServer.Serve(lis); err != nil {
  91. fmt.Printf("Failed to serve: %v\n", err)
  92. }
  93. }
  94. func drawImage(dense *mat.Dense) {
  95. for i := 0; i < 28; i++ {
  96. for j := 0; j < 28; j++ {
  97. val := 0
  98. if dense.At(i*28+j, 0) > 0 {
  99. val = 1
  100. }
  101. fmt.Printf("%v ", val)
  102. }
  103. fmt.Println()
  104. }
  105. }