Browse Source

Update descent early stop

- Change descent early stop criteria
- Add early stop test run before training
Alexey Edelev 4 years ago
parent
commit
cb9f393f25

+ 2 - 30
batchworker/remotebatchworker.proto → batchworker/remotebatchworker.go

@@ -23,35 +23,7 @@
  * DEALINGS IN THE SOFTWARE.
  */
 
- syntax="proto3";
+package batchworker
 
-package remotecontrol;
-
-message None {
-}
-
-message Url {
-    string url = 1;
-}
-
-message Urls {
-    repeated Url list = 1;
-}
-
-message Matrix {
-    bytes matrix = 1;
-}
-
-message DataSet {
-    repeated Matrix data = 1;
-    repeated Matrix result = 2;
-}
-
-service NeuralNetworkCluster {
-    rpc Register(Url) returns (None) {}
-    rpc Providers(None) returns (stream Urls) {}
-}
-
-service BatchWorkerProvider {
-    rpc Run(DataSet) returns (Matrix) {}
+type RemoteBatchWorker struct {
 }

+ 27 - 16
earlystop/simpledescentearlystop.go

@@ -27,17 +27,22 @@ package earlystop
 
 import (
 	"log"
+	"math"
+	"os"
 
 	neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
 	training "git.semlanik.org/semlanik/NeuralNetwork/training"
 )
 
+const tmpFileName = "./.simpleDescentEarlyStop.nnd.tmp"
+
 type simpleDescentEarlyStop struct {
 	lastFailRate     float64
 	bestFailRate     float64
 	failRateDeltaSum float64
 	network          *neuralnetwork.NeuralNetwork
 	trainer          training.Trainer
+	glGrowCount      int
 }
 
 func NewSimpleDescentEarlyStop(network *neuralnetwork.NeuralNetwork, trainer training.Trainer) (es *simpleDescentEarlyStop) {
@@ -47,41 +52,47 @@ func NewSimpleDescentEarlyStop(network *neuralnetwork.NeuralNetwork, trainer tra
 	}
 
 	es = &simpleDescentEarlyStop{
-		lastFailRate:     1.0,
-		bestFailRate:     1.0,
+		lastFailRate:     math.MaxFloat64,
+		bestFailRate:     math.MaxFloat64,
 		failRateDeltaSum: 0.0,
 		network:          network,
 		trainer:          trainer,
+		glGrowCount:      0,
 	}
 	return
 }
 
 func (es *simpleDescentEarlyStop) Test() bool {
 	squareError, fails, total := es.network.Validate(es.trainer)
-	log.Printf("Fail count: %v/%v, error: %v\n", fails, total, squareError)
 
-	failRate := float64(fails) / float64(total)
-	failRateDelta := failRate - es.lastFailRate
-	log.Printf("failRate %v lastFailRate %v failRateDelta %v \n", failRate, es.lastFailRate, failRateDelta)
+	es.lastFailRate = squareError / float64(total)
 
-	es.lastFailRate = failRate
+	log.Printf("Fail count: %v/%v, lastFailRate: %v\n", fails, total, es.lastFailRate)
 
-	if failRateDelta > 0 { //Positive failRateDelta cause fail rate grow, accumulate total grow
-		es.failRateDeltaSum += failRateDelta
-	} else { //Reset failRateDeltaSum in case if we step over one of local maximum
-		es.failRateDeltaSum = 0.0
-	}
+	generalizationLoss := (es.lastFailRate/es.bestFailRate - 1.0)
 
 	if es.bestFailRate > es.lastFailRate {
 		es.bestFailRate = es.lastFailRate
-		//TODO: save neuralnetwork state at this point
+		es.network.SaveStateToFile(tmpFileName)
 	}
 
-	return false //es.failRateDeltaSum > 0.05
+	if generalizationLoss > 0.0 {
+		es.glGrowCount++
+	} else {
+		es.glGrowCount = 0
+	}
+
+	if es.glGrowCount > 5 {
+		es.network.LoadStateFromFile(tmpFileName)
+		os.Remove(tmpFileName)
+		return true
+	}
+	return false
 }
 
 func (es *simpleDescentEarlyStop) Reset() {
-	es.lastFailRate = 1.0
-	es.bestFailRate = 1.0
+	es.lastFailRate = math.MaxFloat64
+	es.bestFailRate = math.MaxFloat64
+	es.glGrowCount = 0
 	es.failRateDeltaSum = 0.0
 }

+ 2 - 1
neuralnetwork/neuralnetwork.go

@@ -348,6 +348,7 @@ func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 
 	if nn.earlyStop != nil {
 		nn.earlyStop.Reset()
+		nn.earlyStop.Test()
 	}
 
 	if _, ok := nn.WGradient[nn.LayerCount-1].(OnlineGradientDescent); ok {
@@ -519,7 +520,7 @@ func (nn *NeuralNetwork) SaveStateToFile(filePath string) {
 func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
-	// Reade count
+	// Read count
 	nn.LayerCount = readInt(reader)
 
 	// Read an input array

+ 1 - 1
remotecontrol/remotecontrol.proto

@@ -23,7 +23,7 @@
  * DEALINGS IN THE SOFTWARE.
  */
 
- syntax="proto3";
+syntax = "proto3";
 
 package remotecontrol;