Browse Source

Add descent early stop usage

TODO: not really working well
Alexey Edelev 4 years ago
parent
commit
250280385a
1 changed files with 9 additions and 5 deletions
  1. 9 5
      handwriting/handwriting.go

+ 9 - 5
handwriting/handwriting.go

@@ -31,10 +31,12 @@ import (
 	fmt "fmt"
 	"net"
 
+	earlystop "git.semlanik.org/semlanik/NeuralNetwork/earlystop"
 	neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
 	gradients "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork/gradients"
 	training "git.semlanik.org/semlanik/NeuralNetwork/training"
-	"gonum.org/v1/gonum/mat"
+
+	mat "gonum.org/v1/gonum/mat"
 	grpc "google.golang.org/grpc"
 )
 
@@ -82,15 +84,17 @@ func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) {
 	fmt.Println("ReTrain")
 
 	trainer := training.NewMNISTReader("./train-images-idx3-ubyte", "./train-labels-idx1-ubyte", "./t10k-images-idx3-ubyte", "./t10k-labels-idx1-ubyte")
-	failCount, total := hws.nn.Validate(trainer)
-	fmt.Printf("Fail count before: %v/%v\n\n", failCount, total)
+	hws.nn.SetEarlyStop(earlystop.NewSimpleDescentEarlyStop(hws.nn, trainer))
+
+	squareError, failCount, total := hws.nn.Validate(trainer)
+	fmt.Printf("Fail count before: %v/%v, error: %v\n\n", failCount, total, squareError)
 
 	hws.nn.Train(trainer, 100)
 
 	hws.nn.SaveStateToFile("./mnistnet.nnd")
 
-	failCount, total = hws.nn.Validate(trainer)
-	fmt.Printf("Fail count after: %v/%v\n\n", failCount, total)
+	squareError, failCount, total = hws.nn.Validate(trainer)
+	fmt.Printf("Fail count after: %v/%v, error: %v\n\n", failCount, total, squareError)
 
 	fmt.Println("ReTrain finished")
 	return &None{}, nil