/* * MIT License * * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com> * * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork * * Permission is hereby granted, free of charge, to any person obtaining a copy of this * software and associated documentation files (the "Software"), to deal in the Software * without restriction, including without limitation the rights to use, copy, modify, * merge, publish, distribute, sublicense, and/or sell copies of the Software, and * to permit persons to whom the Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be included in all copies * or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER * DEALINGS IN THE SOFTWARE. */ package handwriting import ( "bytes" context "context" fmt "fmt" "net" earlystop "git.semlanik.org/semlanik/NeuralNetwork/earlystop" neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork" gradients "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork/gradients" training "git.semlanik.org/semlanik/NeuralNetwork/training" mat "gonum.org/v1/gonum/mat" grpc "google.golang.org/grpc" ) type HandwritingService struct { nn *neuralnetwork.NeuralNetwork } func NewHandwritingService() (hws *HandwritingService) { hws = &HandwritingService{} hws.nn, _ = neuralnetwork.NewNeuralNetwork([]int{784, 300, 10}, gradients.NewRPropInitializer(gradients.RPropConfig{ NuPlus: 1.2, NuMinus: 0.5, DeltaMax: 50.0, DeltaMin: 0.000001, })) hws.nn.SetStateWatcher(hws) return } func (hws *HandwritingService) Recognize(ctx context.Context, matrix *Matrix) (*Result, error) { fmt.Printf("Recognize %v size: %v\n", len(matrix.Data), hws.nn.Sizes[0]) dense := mat.NewDense(hws.nn.Sizes[0], 1, matrix.Data) index, _ := hws.nn.Predict(dense) fmt.Printf("Recognition result %v\n", index) return &Result{ResultCharacter: uint32(index)}, nil } func (hws *HandwritingService) SetNeuralNetworkData(ctx context.Context, nnRaw *NeuralNetworkRaw) (*None, error) { fmt.Println("SetNeuralNetworkData") r := bytes.NewReader(nnRaw.Data) hws.nn.LoadState(r) return &None{}, nil } func (hws *HandwritingService) GetNeuralNetworkData(context.Context, *None) (*NeuralNetworkRaw, error) { nnRaw := &NeuralNetworkRaw{} fmt.Println("SetNeuralNetworkData") r := bytes.NewReader(nnRaw.Data) hws.nn.LoadState(r) return nnRaw, nil } func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) { fmt.Println("ReTrain") trainer := training.NewMNISTReader("./train-images-idx3-ubyte", "./train-labels-idx1-ubyte", "./t10k-images-idx3-ubyte", "./t10k-labels-idx1-ubyte") hws.nn.SetEarlyStop(earlystop.NewSimpleDescentEarlyStop(hws.nn, trainer)) squareError, failCount, total := hws.nn.Validate(trainer) fmt.Printf("Fail count before: %v/%v, error: %v\n\n", failCount, total, squareError) hws.nn.Train(trainer, 100) hws.nn.SaveStateToFile("./mnistnet.nnd") squareError, failCount, total = hws.nn.Validate(trainer) fmt.Printf("Fail count after: %v/%v, error: %v\n\n", failCount, total, squareError) fmt.Println("ReTrain finished") return &None{}, nil } func (hws *HandwritingService) Run() { grpcServer := grpc.NewServer() RegisterHandwritingServer(grpcServer, hws) lis, err := net.Listen("tcp", "localhost:65001") if err != nil { fmt.Printf("Failed to listen: %v\n", err) } fmt.Printf("Listen localhost:65001\n") if err := grpcServer.Serve(lis); err != nil { fmt.Printf("Failed to serve: %v\n", err) } } func (hws *HandwritingService) Init(nn *neuralnetwork.NeuralNetwork) { } func (hws *HandwritingService) UpdateState(int) { } func (hws *HandwritingService) UpdateActivations(int, *mat.Dense) { } func (hws *HandwritingService) UpdateBiases(int, *mat.Dense) { } func (hws *HandwritingService) UpdateWeights(int, *mat.Dense) { } func (hws *HandwritingService) UpdateTraining(t int, epocs int, samplesProcced int, totalSamplesCount int) { fmt.Printf("Training progress: Epoc: %v/%v\n", t, epocs) } func (hws *HandwritingService) UpdateValidation(validatorCount int, failCount int) { } func (hws *HandwritingService) GetSubscriptionFeatures() (features neuralnetwork.SubscriptionFeatures) { features = 0 features.Set(neuralnetwork.TrainingSubscription) features.Set(neuralnetwork.ValidationSubscription) return } func drawImage(dense *mat.Dense) { for i := 0; i < 28; i++ { for j := 0; j < 28; j++ { val := 0 if dense.At(i*28+j, 0) > 0 { val = 1 } fmt.Printf("%v ", val) } fmt.Println() } }