123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- package handwriting
- import (
- "bytes"
- context "context"
- fmt "fmt"
- "net"
- earlystop "git.semlanik.org/semlanik/NeuralNetwork/earlystop"
- neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
- gradients "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork/gradients"
- training "git.semlanik.org/semlanik/NeuralNetwork/training"
- mat "gonum.org/v1/gonum/mat"
- grpc "google.golang.org/grpc"
- )
- type HandwritingService struct {
- nn *neuralnetwork.NeuralNetwork
- }
- func NewHandwritingService() (hws *HandwritingService) {
- hws = &HandwritingService{}
- hws.nn, _ = neuralnetwork.NewNeuralNetwork([]int{784, 300, 10}, gradients.NewRPropInitializer(gradients.RPropConfig{
- NuPlus: 1.2,
- NuMinus: 0.5,
- DeltaMax: 50.0,
- DeltaMin: 0.000001,
- }))
- hws.nn.SetStateWatcher(hws)
- return
- }
- func (hws *HandwritingService) Recognize(ctx context.Context, matrix *Matrix) (*Result, error) {
- fmt.Printf("Recognize %v size: %v\n", len(matrix.Data), hws.nn.Sizes[0])
- dense := mat.NewDense(hws.nn.Sizes[0], 1, matrix.Data)
- index, _ := hws.nn.Predict(dense)
- fmt.Printf("Recognition result %v\n", index)
- return &Result{ResultCharacter: uint32(index)}, nil
- }
- func (hws *HandwritingService) SetNeuralNetworkData(ctx context.Context, nnRaw *NeuralNetworkRaw) (*None, error) {
- fmt.Println("SetNeuralNetworkData")
- r := bytes.NewReader(nnRaw.Data)
- hws.nn.LoadState(r)
- return &None{}, nil
- }
- func (hws *HandwritingService) GetNeuralNetworkData(context.Context, *None) (*NeuralNetworkRaw, error) {
- nnRaw := &NeuralNetworkRaw{}
- fmt.Println("SetNeuralNetworkData")
- r := bytes.NewReader(nnRaw.Data)
- hws.nn.LoadState(r)
- return nnRaw, nil
- }
- func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) {
- fmt.Println("ReTrain")
- trainer := training.NewMNISTReader("./train-images-idx3-ubyte", "./train-labels-idx1-ubyte", "./t10k-images-idx3-ubyte", "./t10k-labels-idx1-ubyte")
- hws.nn.SetEarlyStop(earlystop.NewSimpleDescentEarlyStop(hws.nn, trainer))
- squareError, failCount, total := hws.nn.Validate(trainer)
- fmt.Printf("Fail count before: %v/%v, error: %v\n\n", failCount, total, squareError)
- hws.nn.Train(trainer, 100)
- hws.nn.SaveStateToFile("./mnistnet.nnd")
- squareError, failCount, total = hws.nn.Validate(trainer)
- fmt.Printf("Fail count after: %v/%v, error: %v\n\n", failCount, total, squareError)
- fmt.Println("ReTrain finished")
- return &None{}, nil
- }
- func (hws *HandwritingService) Run() {
- grpcServer := grpc.NewServer()
- RegisterHandwritingServer(grpcServer, hws)
- lis, err := net.Listen("tcp", "localhost:65001")
- if err != nil {
- fmt.Printf("Failed to listen: %v\n", err)
- }
- fmt.Printf("Listen localhost:65001\n")
- if err := grpcServer.Serve(lis); err != nil {
- fmt.Printf("Failed to serve: %v\n", err)
- }
- }
- func (hws *HandwritingService) Init(nn *neuralnetwork.NeuralNetwork) {
- }
- func (hws *HandwritingService) UpdateState(int) {
- }
- func (hws *HandwritingService) UpdateActivations(int, *mat.Dense) {
- }
- func (hws *HandwritingService) UpdateBiases(int, *mat.Dense) {
- }
- func (hws *HandwritingService) UpdateWeights(int, *mat.Dense) {
- }
- func (hws *HandwritingService) UpdateTraining(t int, epocs int, samplesProcced int, totalSamplesCount int) {
- fmt.Printf("Training progress: Epoc: %v/%v\n", t, epocs)
- }
- func (hws *HandwritingService) UpdateValidation(validatorCount int, failCount int) {
- }
- func (hws *HandwritingService) GetSubscriptionFeatures() (features neuralnetwork.SubscriptionFeatures) {
- features = 0
- features.Set(neuralnetwork.TrainingSubscription)
- features.Set(neuralnetwork.ValidationSubscription)
- return
- }
- func drawImage(dense *mat.Dense) {
- for i := 0; i < 28; i++ {
- for j := 0; j < 28; j++ {
- val := 0
- if dense.At(i*28+j, 0) > 0 {
- val = 1
- }
- fmt.Printf("%v ", val)
- }
- fmt.Println()
- }
- }
|