|
@@ -31,10 +31,12 @@ import (
|
|
|
fmt "fmt"
|
|
|
"net"
|
|
|
|
|
|
+ earlystop "git.semlanik.org/semlanik/NeuralNetwork/earlystop"
|
|
|
neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
|
|
|
gradients "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork/gradients"
|
|
|
training "git.semlanik.org/semlanik/NeuralNetwork/training"
|
|
|
- "gonum.org/v1/gonum/mat"
|
|
|
+
|
|
|
+ mat "gonum.org/v1/gonum/mat"
|
|
|
grpc "google.golang.org/grpc"
|
|
|
)
|
|
|
|
|
@@ -82,15 +84,17 @@ func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) {
|
|
|
fmt.Println("ReTrain")
|
|
|
|
|
|
trainer := training.NewMNISTReader("./train-images-idx3-ubyte", "./train-labels-idx1-ubyte", "./t10k-images-idx3-ubyte", "./t10k-labels-idx1-ubyte")
|
|
|
- failCount, total := hws.nn.Validate(trainer)
|
|
|
- fmt.Printf("Fail count before: %v/%v\n\n", failCount, total)
|
|
|
+ hws.nn.SetEarlyStop(earlystop.NewSimpleDescentEarlyStop(hws.nn, trainer))
|
|
|
+
|
|
|
+ squareError, failCount, total := hws.nn.Validate(trainer)
|
|
|
+ fmt.Printf("Fail count before: %v/%v, error: %v\n\n", failCount, total, squareError)
|
|
|
|
|
|
hws.nn.Train(trainer, 100)
|
|
|
|
|
|
hws.nn.SaveStateToFile("./mnistnet.nnd")
|
|
|
|
|
|
- failCount, total = hws.nn.Validate(trainer)
|
|
|
- fmt.Printf("Fail count after: %v/%v\n\n", failCount, total)
|
|
|
+ squareError, failCount, total = hws.nn.Validate(trainer)
|
|
|
+ fmt.Printf("Fail count after: %v/%v, error: %v\n\n", failCount, total, squareError)
|
|
|
|
|
|
fmt.Println("ReTrain finished")
|
|
|
return &None{}, nil
|