|
@@ -27,17 +27,22 @@ package earlystop
|
|
|
|
|
|
import (
|
|
import (
|
|
"log"
|
|
"log"
|
|
|
|
+ "math"
|
|
|
|
+ "os"
|
|
|
|
|
|
neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
|
|
neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
|
|
training "git.semlanik.org/semlanik/NeuralNetwork/training"
|
|
training "git.semlanik.org/semlanik/NeuralNetwork/training"
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+const tmpFileName = "./.simpleDescentEarlyStop.nnd.tmp"
|
|
|
|
+
|
|
type simpleDescentEarlyStop struct {
|
|
type simpleDescentEarlyStop struct {
|
|
lastFailRate float64
|
|
lastFailRate float64
|
|
bestFailRate float64
|
|
bestFailRate float64
|
|
failRateDeltaSum float64
|
|
failRateDeltaSum float64
|
|
network *neuralnetwork.NeuralNetwork
|
|
network *neuralnetwork.NeuralNetwork
|
|
trainer training.Trainer
|
|
trainer training.Trainer
|
|
|
|
+ glGrowCount int
|
|
}
|
|
}
|
|
|
|
|
|
func NewSimpleDescentEarlyStop(network *neuralnetwork.NeuralNetwork, trainer training.Trainer) (es *simpleDescentEarlyStop) {
|
|
func NewSimpleDescentEarlyStop(network *neuralnetwork.NeuralNetwork, trainer training.Trainer) (es *simpleDescentEarlyStop) {
|
|
@@ -47,41 +52,47 @@ func NewSimpleDescentEarlyStop(network *neuralnetwork.NeuralNetwork, trainer tra
|
|
}
|
|
}
|
|
|
|
|
|
es = &simpleDescentEarlyStop{
|
|
es = &simpleDescentEarlyStop{
|
|
- lastFailRate: 1.0,
|
|
|
|
- bestFailRate: 1.0,
|
|
|
|
|
|
+ lastFailRate: math.MaxFloat64,
|
|
|
|
+ bestFailRate: math.MaxFloat64,
|
|
failRateDeltaSum: 0.0,
|
|
failRateDeltaSum: 0.0,
|
|
network: network,
|
|
network: network,
|
|
trainer: trainer,
|
|
trainer: trainer,
|
|
|
|
+ glGrowCount: 0,
|
|
}
|
|
}
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
func (es *simpleDescentEarlyStop) Test() bool {
|
|
func (es *simpleDescentEarlyStop) Test() bool {
|
|
squareError, fails, total := es.network.Validate(es.trainer)
|
|
squareError, fails, total := es.network.Validate(es.trainer)
|
|
- log.Printf("Fail count: %v/%v, error: %v\n", fails, total, squareError)
|
|
|
|
|
|
|
|
- failRate := float64(fails) / float64(total)
|
|
|
|
- failRateDelta := failRate - es.lastFailRate
|
|
|
|
- log.Printf("failRate %v lastFailRate %v failRateDelta %v \n", failRate, es.lastFailRate, failRateDelta)
|
|
|
|
|
|
+ es.lastFailRate = squareError / float64(total)
|
|
|
|
|
|
- es.lastFailRate = failRate
|
|
|
|
|
|
+ log.Printf("Fail count: %v/%v, lastFailRate: %v\n", fails, total, es.lastFailRate)
|
|
|
|
|
|
- if failRateDelta > 0 { //Positive failRateDelta cause fail rate grow, accumulate total grow
|
|
|
|
- es.failRateDeltaSum += failRateDelta
|
|
|
|
- } else { //Reset failRateDeltaSum in case if we step over one of local maximum
|
|
|
|
- es.failRateDeltaSum = 0.0
|
|
|
|
- }
|
|
|
|
|
|
+ generalizationLoss := (es.lastFailRate/es.bestFailRate - 1.0)
|
|
|
|
|
|
if es.bestFailRate > es.lastFailRate {
|
|
if es.bestFailRate > es.lastFailRate {
|
|
es.bestFailRate = es.lastFailRate
|
|
es.bestFailRate = es.lastFailRate
|
|
- //TODO: save neuralnetwork state at this point
|
|
|
|
|
|
+ es.network.SaveStateToFile(tmpFileName)
|
|
}
|
|
}
|
|
|
|
|
|
- return false //es.failRateDeltaSum > 0.05
|
|
|
|
|
|
+ if generalizationLoss > 0.0 {
|
|
|
|
+ es.glGrowCount++
|
|
|
|
+ } else {
|
|
|
|
+ es.glGrowCount = 0
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if es.glGrowCount > 5 {
|
|
|
|
+ es.network.LoadStateFromFile(tmpFileName)
|
|
|
|
+ os.Remove(tmpFileName)
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+ return false
|
|
}
|
|
}
|
|
|
|
|
|
func (es *simpleDescentEarlyStop) Reset() {
|
|
func (es *simpleDescentEarlyStop) Reset() {
|
|
- es.lastFailRate = 1.0
|
|
|
|
- es.bestFailRate = 1.0
|
|
|
|
|
|
+ es.lastFailRate = math.MaxFloat64
|
|
|
|
+ es.bestFailRate = math.MaxFloat64
|
|
|
|
+ es.glGrowCount = 0
|
|
es.failRateDeltaSum = 0.0
|
|
es.failRateDeltaSum = 0.0
|
|
}
|
|
}
|