
Add EarlyStop and BatchWorker interfaces

- Add remote batch worker interface (not finalized yet)
- Add EarlyStop interface
- Add reference early stop implementations
- Make the local batch worker the default
- Rework the Validate function: it no longer uses the standard Predict and
  now also reports the mean squared error (see the usage sketch below)
Alexey Edelev committed 4 years ago · commit b7030e0f47
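
For orientation, a minimal usage sketch of the API surface this commit adds. It assumes an already constructed *neuralnetwork.NeuralNetwork and a training.Trainer from the rest of the project; the epoch count is arbitrary:

    // Optional: install an early stop analyzer (noEarlyStop is the built-in default).
    network.SetEarlyStop(earlystop.NewSimpleDescentEarlyStop(network, trainer))

    // Optional: install a batch worker factory (the local factory is the default).
    network.SetBatchWorkerFactory(neuralnetwork.NewLocalBatchWorkerFactory(network))

    network.Train(trainer, 100)

    // Validate now also reports the mean squared error.
    squareError, fails, total := network.Validate(trainer)
    log.Printf("error: %v, fails: %v/%v\n", squareError, fails, total)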

+ 57 - 0
batchworker/remotebatchworker.proto

@@ -0,0 +1,57 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2020 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+syntax = "proto3";
+
+package remotecontrol;
+
+message None {
+}
+
+message Url {
+    string url = 1;
+}
+
+message Urls {
+    repeated Url list = 1;
+}
+
+message Matrix {
+    bytes matrix = 1;
+}
+
+message DataSet {
+    repeated Matrix data = 1;
+    repeated Matrix result = 2;
+}
+
+service NeuralNetworkCluster {
+    rpc Register(Url) returns (None) {}
+    rpc Providers(None) returns (stream Urls) {}
+}
+
+service BatchWorkerProvider {
+    rpc Run(DataSet) returns (Matrix) {}
+}
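
The remote services are not finalized per the commit message, but for illustration, here is a hedged sketch of how a batch worker provider might announce itself to the cluster, assuming Go stubs generated from this file with protoc-gen-go/protoc-gen-go-grpc; the stub import path and endpoint addresses are hypothetical:

    import (
        "context"
        "log"

        remotecontrol "git.semlanik.org/semlanik/NeuralNetwork/batchworker/remotecontrol" // assumed stub path
        grpc "google.golang.org/grpc"
    )

    func registerProvider() {
        conn, err := grpc.Dial("cluster.example.org:5001", grpc.WithInsecure())
        if err != nil {
            log.Fatalf("could not reach cluster: %v", err)
        }
        defer conn.Close()

        cluster := remotecontrol.NewNeuralNetworkClusterClient(conn)
        // Announce this provider's own BatchWorkerProvider endpoint.
        _, err = cluster.Register(context.Background(), &remotecontrol.Url{Url: "worker1.example.org:5002"})
        if err != nil {
            log.Fatalf("registration failed: %v", err)
        }
    }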

+ 64 - 0
earlystop/constantrateearlystop.go

@@ -0,0 +1,64 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2020 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+package earlystop
+
+import (
+	"log"
+
+	neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
+	training "git.semlanik.org/semlanik/NeuralNetwork/training"
+)
+
+type constantRateEarlyStop struct {
+	network     *neuralnetwork.NeuralNetwork
+	trainer     training.Trainer
+	minFailRate float64
+}
+
+func NewConstantRateEarlyStop(minFailRate float64, network *neuralnetwork.NeuralNetwork, trainer training.Trainer) (es *constantRateEarlyStop) {
+	es = nil
+	if network == nil || trainer == nil {
+		return
+	}
+
+	es = &constantRateEarlyStop{
+		network:     network,
+		trainer:     trainer,
+		minFailRate: minFailRate,
+	}
+	return
+}
+
+func (es *constantRateEarlyStop) Test() bool {
+	_, fails, total := es.network.Validate(es.trainer)
+	log.Printf("Fail count: %v/%v\n", fails, total)
+
+	failRate := float64(fails) / float64(total)
+	return es.minFailRate >= failRate
+}
+
+func (es *constantRateEarlyStop) Reset() {
+}
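
Usage sketch: Test() reports true once the validation fail rate is at or below minFailRate, so the first constructor argument is the target fail rate. With 0.02, training stops as soon as no more than 2% of validation samples fail (network and trainer assumed from elsewhere):

    if es := earlystop.NewConstantRateEarlyStop(0.02, network, trainer); es != nil {
        network.SetEarlyStop(es)
    }
    network.Train(trainer, 100)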

+ 87 - 0
earlystop/simpledescentearlystop.go

@@ -0,0 +1,87 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2020 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+package earlystop
+
+import (
+	"log"
+
+	neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
+	training "git.semlanik.org/semlanik/NeuralNetwork/training"
+)
+
+type simpleDescentEarlyStop struct {
+	lastFailRate     float64
+	bestFailRate     float64
+	failRateDeltaSum float64
+	network          *neuralnetwork.NeuralNetwork
+	trainer          training.Trainer
+}
+
+func NewSimpleDescentEarlyStop(network *neuralnetwork.NeuralNetwork, trainer training.Trainer) (es *simpleDescentEarlyStop) {
+	es = nil
+	if network == nil || trainer == nil {
+		return
+	}
+
+	es = &simpleDescentEarlyStop{
+		lastFailRate:     1.0,
+		bestFailRate:     1.0,
+		failRateDeltaSum: 0.0,
+		network:          network,
+		trainer:          trainer,
+	}
+	return
+}
+
+func (es *simpleDescentEarlyStop) Test() bool {
+	squareError, fails, total := es.network.Validate(es.trainer)
+	log.Printf("Fail count: %v/%v, error: %v\n", fails, total, squareError)
+
+	failRate := float64(fails) / float64(total)
+	failRateDelta := failRate - es.lastFailRate
+	log.Printf("failRate %v lastFailRate %v failRateDelta %v \n", failRate, es.lastFailRate, failRateDelta)
+
+	es.lastFailRate = failRate
+
+	if failRateDelta > 0 { // Positive failRateDelta means the fail rate grew; accumulate the total growth
+		es.failRateDeltaSum += failRateDelta
+	} else { // Reset failRateDeltaSum in case we stepped over a local maximum
+		es.failRateDeltaSum = 0.0
+	}
+
+	if es.bestFailRate > es.lastFailRate {
+		es.bestFailRate = es.lastFailRate
+		//TODO: save neuralnetwork state at this point
+	}
+
+	return false // Stop condition disabled for now: es.failRateDeltaSum > 0.05
+}
+
+func (es *simpleDescentEarlyStop) Reset() {
+	es.lastFailRate = 1.0
+	es.bestFailRate = 1.0
+	es.failRateDeltaSum = 0.0
+}
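
As committed, Test() always returns false: the es.failRateDeltaSum > 0.05 stop condition is commented out, so this analyzer currently only tracks bestFailRate (the TODO marks where the best network state would be saved). It satisfies the same EarlyStop interface, so it is a drop-in replacement for the constant-rate variant:

    network.SetEarlyStop(earlystop.NewSimpleDescentEarlyStop(network, trainer))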

+ 16 - 0
neuralnetwork/interface.go

@@ -26,6 +26,7 @@
 package neuralnetwork
 
 import (
+	training "git.semlanik.org/semlanik/NeuralNetwork/training"
 	mat "gonum.org/v1/gonum/mat"
 )
 
@@ -74,3 +75,18 @@ func (f *SubscriptionFeatures) Unset(flag SubscriptionFeatures) {
 func (f *SubscriptionFeatures) Clear() {
 	*f = 0
 }
+
+type BatchWorker interface {
+	Run(trainer training.Trainer, startIndex, endIndex int)
+	Result(layer int) (dB, dW *mat.Dense)
+}
+
+type BatchWorkerFactory interface {
+	GetBatchWorker() BatchWorker
+	GetAvailableThreads() int
+}
+
+type EarlyStop interface {
+	Test() bool
+	Reset()
+}
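
Any type with Test() bool and Reset() satisfies EarlyStop. A hypothetical minimal implementation (not part of this commit) that stops after a fixed number of validation checks, just to illustrate the contract:

    type maxChecksEarlyStop struct {
        checks int
        limit  int
    }

    func (es *maxChecksEarlyStop) Test() bool {
        es.checks++
        return es.checks >= es.limit
    }

    func (es *maxChecksEarlyStop) Reset() {
        es.checks = 0
    }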

+ 28 - 6
neuralnetwork/batchworker.go → neuralnetwork/localbatchworker.go

@@ -1,7 +1,7 @@
 /*
  * MIT License
  *
- * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
+ * Copyright (c) 2020 Alexey Edelev <semlanik@gmail.com>
  *
  * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
  *
@@ -26,19 +26,33 @@
 package neuralnetwork
 
 import (
+	"runtime"
+
 	training "git.semlanik.org/semlanik/NeuralNetwork/training"
 	mat "gonum.org/v1/gonum/mat"
 )
 
-type batchWorker struct {
+type localBatchWorkerFactory struct {
+	network *NeuralNetwork
+}
+
+type localBatchWorker struct {
 	network   *NeuralNetwork
 	BGradient []BatchGradientDescent
 	WGradient []BatchGradientDescent
 	batchSize int
 }
 
-func newBatchWorker(nn *NeuralNetwork) (bw *batchWorker) {
-	bw = &batchWorker{
+func NewLocalBatchWorkerFactory(network *NeuralNetwork) BatchWorkerFactory {
+	factory := &localBatchWorkerFactory{
+		network: network,
+	}
+
+	return factory
+}
+
+func newLocalBatchWorker(nn *NeuralNetwork) (bw *localBatchWorker) {
+	bw = &localBatchWorker{
 		network:   nn,
 		BGradient: make([]BatchGradientDescent, nn.LayerCount),
 		WGradient: make([]BatchGradientDescent, nn.LayerCount),
@@ -51,7 +65,7 @@ func newBatchWorker(nn *NeuralNetwork) (bw *batchWorker) {
 	return
 }
 
-func (bw *batchWorker) run(trainer training.Trainer, startIndex, endIndex int) {
+func (bw *localBatchWorker) Run(trainer training.Trainer, startIndex, endIndex int) {
 	for i := startIndex; i < endIndex; i++ {
 		dB, dW := bw.network.backward(trainer.GetData(i))
 		for l := 1; l < bw.network.LayerCount; l++ {
@@ -61,6 +75,14 @@ func (bw *batchWorker) run(trainer training.Trainer, startIndex, endIndex int) {
 	}
 }
 
-func (bw *batchWorker) result(layer int) (dB, dW *mat.Dense) {
+func (bw *localBatchWorker) Result(layer int) (dB, dW *mat.Dense) {
 	return bw.BGradient[layer].Gradients(), bw.WGradient[layer].Gradients()
 }
+
+func (lbwf localBatchWorkerFactory) GetBatchWorker() BatchWorker {
+	return newLocalBatchWorker(lbwf.network)
+}
+
+func (lbwf localBatchWorkerFactory) GetAvailableThreads() int {
+	return runtime.NumCPU()
+}
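
During batch training the network asks the factory for GetAvailableThreads() and spawns one worker per thread, each covering a contiguous chunk of the training data. The local factory is installed automatically when none is set, but it can also be set explicitly (sketch, network assumed):

    network.SetBatchWorkerFactory(neuralnetwork.NewLocalBatchWorkerFactory(network))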

+ 97 - 26
neuralnetwork/neuralnetwork.go

@@ -30,8 +30,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
+	"math"
 	"os"
-	"runtime"
 	"sync"
 	"time"
 
@@ -109,6 +110,8 @@ type NeuralNetwork struct {
 	gradientDescentInitializer GradientDescentInitializer
 	watcher                    StateWatcher
 	syncMutex                  *sync.Mutex
+	batchWorkerFactory         BatchWorkerFactory
+	earlyStop                  EarlyStop
 }
 
 // NewNeuralNetwork construction method that initializes new NeuralNetwork based
@@ -141,6 +144,7 @@ func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentIni
 		WGradient:                  make([]interface{}, lenSizes),
 		gradientDescentInitializer: gradientDescentInitializer,
 		syncMutex:                  &sync.Mutex{},
+		earlyStop:                  &noEarlyStop{},
 	}
 
 	for l := 1; l < nn.LayerCount; l++ {
@@ -170,6 +174,7 @@ func (nn *NeuralNetwork) Copy() (outNN *NeuralNetwork) {
 		gradientDescentInitializer: nn.gradientDescentInitializer,
 		watcher:                    nn.watcher,
 		syncMutex:                  &sync.Mutex{},
+		earlyStop:                  &noEarlyStop{},
 	}
 	for l := 1; l < outNN.LayerCount; l++ {
 		outNN.Biases[l] = mat.DenseCopyOf(nn.Biases[l])
@@ -182,8 +187,20 @@ func (nn *NeuralNetwork) Copy() (outNN *NeuralNetwork) {
 	return
 }
 
+// SetBatchWorkerFactory sets the batch worker factory used for batch training. If no
+// factory is set, localBatchWorkerFactory is used by default.
+func (nn *NeuralNetwork) SetBatchWorkerFactory(factory BatchWorkerFactory) {
+	nn.batchWorkerFactory = factory
+}
+
+// SetEarlyStop sets an early stop analyzer that can stop training before all training
+// epochs have finished. Early stopping is usually used to avoid overfitting the network.
+func (nn *NeuralNetwork) SetEarlyStop(earlyStop EarlyStop) {
+	nn.earlyStop = earlyStop
+}
+
 // Reset resets network state to initial/random one with specified in argument
-// layers configuration
+// layers configuration.
 func (nn *NeuralNetwork) Reset(sizes []int) (err error) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -222,7 +239,7 @@ func (nn *NeuralNetwork) Reset(sizes []int) (err error) {
 
 // SetStateWatcher setups state watcher for NeuralNetwork. StateWatcher is common
 // interface that collects data about NeuralNetwork behavior. If not specified (is
-// set to nil) NeuralNetwork will ignore StateWatcher interations
+// set to nil) NeuralNetwork will ignore StateWatcher interactions.
 func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
 	nn.watcher = watcher
 	if watcher != nil {
@@ -234,7 +251,7 @@ func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
 }
 
 // Predict method invokes prediction based on input activations provided in argument.
-// Returns index of best element in output activation matrix and its value
+// Returns index of best element in output activation matrix and its value.
 func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -266,20 +283,50 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 
 // Validate runs basic network validation/verification based on validation data that
 // provided by training.Trainer passed as argument.
-// Returns count of failure predictions and total amount of verified samples.
-func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total int) {
+// Returns the mean squared error over all samples, the count of failed predictions, and the total number of verified samples.
+func (nn *NeuralNetwork) Validate(trainer training.Trainer) (squareError float64, failCount, total int) {
 	failCount = 0
+	squareError = 0.0
 	total = trainer.ValidatorCount()
+	nn.syncMutex.Lock()
+	defer nn.syncMutex.Unlock()
+	if nn.watcher != nil {
+		if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
+			nn.watcher.UpdateState(StateValidation)
+			defer nn.watcher.UpdateState(StateIdle)
+		}
+	}
 	for i := 0; i < trainer.ValidatorCount(); i++ {
-		dataSet, expect := trainer.GetValidator(i)
-		index, _ := nn.Predict(dataSet)
-		if expect.At(index, 0) != 1.0 {
+		aIn, aOut := trainer.GetValidator(i)
+		r, _ := aIn.Dims()
+		if r != nn.Sizes[0] {
+			fmt.Printf("Invalid input matrix row count: %v\n", r)
+			return math.MaxFloat64, total, total
+		}
+
+		A, _ := nn.forward(aIn)
+		result := A[nn.LayerCount-1]
+		r, _ = result.Dims()
+
+		err := &mat.Dense{}
+		err.Sub(result, aOut)
+
+		var squareErrorLocal float64 = 0.0
+		max := 0.0
+		maxIndex := 0
+		for i := 0; i < r; i++ {
+			if result.At(i, 0) > max {
+				max = result.At(i, 0)
+				maxIndex = i
+			}
+			squareErrorLocal += err.At(i, 0) * err.At(i, 0)
+		}
+		if aOut.At(maxIndex, 0) != 1.0 {
 			failCount++
 		}
+		squareError += squareErrorLocal / float64(r)
 	}
 
-	nn.syncMutex.Lock()
-	defer nn.syncMutex.Unlock()
 	if nn.watcher != nil {
 		if nn.watcher.GetSubscriptionFeatures().Has(ValidationSubscription) {
 			nn.watcher.UpdateValidation(total, failCount)
@@ -290,7 +337,7 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
 
 // Train is common training function that invokes one of training methods depends on
 // gradient descent used by NeuralNetwork. training.Trainer passed as argument used
-// to get training data. Training loops are limited buy number of epocs
+// to get training data. Training loops are limited by the number of epochs.
 func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 	if nn.watcher != nil {
 		if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
@@ -298,6 +345,11 @@ func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 			defer nn.watcher.UpdateState(StateIdle)
 		}
 	}
+
+	if nn.earlyStop != nil {
+		nn.earlyStop.Reset()
+	}
+
 	if _, ok := nn.WGradient[nn.LayerCount-1].(OnlineGradientDescent); ok {
 		nn.trainOnline(trainer, epocs)
 	} else if _, ok := nn.WGradient[nn.LayerCount-1].(BatchGradientDescent); ok {
@@ -308,6 +360,7 @@ func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 }
 
 func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
+
 	for t := 0; t < epocs; t++ {
 		for i := 0; i < trainer.DataCount(); i++ {
 			if nn.watcher != nil {
@@ -339,18 +392,22 @@ func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
 			}
 			nn.syncMutex.Unlock()
 		}
+
+		if nn.earlyStop != nil && nn.earlyStop.Test() {
+			log.Printf("Training stopped due to fail rate growth\n")
+			break
+		}
 	}
 }
 
 func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
-	fmt.Printf("Start training in %v threads\n", runtime.NumCPU())
 	for t := 0; t < epocs; t++ {
 		if nn.watcher != nil {
 			if nn.watcher.GetSubscriptionFeatures().Has(TrainingSubscription) {
 				nn.watcher.UpdateTraining(t, epocs, 0, trainer.DataCount())
 			}
 		}
-		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
+		batchWorkers := nn.runBatchWorkers(trainer)
 		nn.syncMutex.Lock()
 		for l := 1; l < nn.LayerCount; l++ {
 			bGradient, ok := nn.BGradient[l].(BatchGradientDescent)
@@ -362,7 +419,7 @@ func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
 				panic("wGradient is not a BatchGradientDescent")
 			}
 			for _, bw := range batchWorkers {
-				dB, dW := bw.result(l)
+				dB, dW := bw.Result(l)
 				bGradient.AccumGradients(dB)
 				wGradient.AccumGradients(dW)
 			}
@@ -378,21 +435,35 @@ func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
 			}
 		}
 		nn.syncMutex.Unlock()
-		//TODO: remove this is not used for visualization
-		time.Sleep(100 * time.Millisecond)
+
+		if nn.earlyStop != nil && nn.earlyStop.Test() {
+			log.Printf("Training stopped due to fail rate growth\n")
+			break
+		}
+
+		if nn.watcher != nil && (nn.watcher.GetSubscriptionFeatures().Has(BiasesSubscription) || nn.watcher.GetSubscriptionFeatures().Has(WeightsSubscription)) {
+			time.Sleep(100 * time.Millisecond) // TODO: better to add a 'Latency() int' method to the watcher for this check
+		}
 	}
 }
 
-func (nn *NeuralNetwork) runBatchWorkers(threadCount int, trainer training.Trainer) (workers []*batchWorker) {
+func (nn *NeuralNetwork) runBatchWorkers(trainer training.Trainer) (workers []BatchWorker) {
+	if nn.batchWorkerFactory == nil {
+		nn.batchWorkerFactory = NewLocalBatchWorkerFactory(nn)
+		log.Printf("Batch worker factory is not set, using the local one\n")
+	}
+
 	wg := sync.WaitGroup{}
+	threadCount := nn.batchWorkerFactory.GetAvailableThreads()
+
 	chunkSize := trainer.DataCount() / threadCount
-	workers = make([]*batchWorker, threadCount)
+	workers = make([]BatchWorker, threadCount)
 	for i, _ := range workers {
-		workers[i] = newBatchWorker(nn)
+		workers[i] = nn.batchWorkerFactory.GetBatchWorker()
 		wg.Add(1)
 		s := i
 		go func() {
-			workers[s].run(trainer, s*chunkSize, (s+1)*chunkSize)
+			workers[s].Run(trainer, s*chunkSize, (s+1)*chunkSize)
 			wg.Done()
 		}()
 	}
@@ -402,7 +473,7 @@ func (nn *NeuralNetwork) runBatchWorkers(threadCount int, trainer training.Train
 
 // SaveState saves state of NeuralNetwork to io.Writer. It's useful to keep training results
 // between NeuralNetwork "power cycles" or to share training results between clustered neural
-// network hosts
+// network hosts.
 func (nn *NeuralNetwork) SaveState(writer io.Writer) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -435,7 +506,7 @@ func (nn *NeuralNetwork) SaveState(writer io.Writer) {
 	}
 }
 
-// SaveStateToFile saves NeuralNetwork state to file by specific filePath
+// SaveStateToFile saves NeuralNetwork state to the file at the given filePath.
 func (nn *NeuralNetwork) SaveStateToFile(filePath string) {
 	outFile, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
 	check(err)
@@ -444,7 +515,7 @@ func (nn *NeuralNetwork) SaveStateToFile(filePath string) {
 }
 
 // LoadState loads NeuralNetwork state from io.Reader. All existing data in NeuralNetwork
-// will be rewritten buy this method, including layers configuration and weights and biases
+// will be rewritten by this method, including layers configuration and weights and biases.
 func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -486,7 +557,7 @@ func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 	// fmt.Printf("\nLoadState end\n")
 }
 
-// LoadStateFromFile loads NeuralNetwork state from file by specific filePath
+// LoadStateFromFile loads NeuralNetwork state from the file at the given filePath.
 func (nn *NeuralNetwork) LoadStateFromFile(filePath string) {
 	inFile, err := os.Open(filePath)
 	check(err)
@@ -535,7 +606,7 @@ func (nn NeuralNetwork) forward(aIn mat.Matrix) (A, Z []*mat.Dense) {
 }
 
 // Function returns calculated bias and weights derivatives for each
-// layer arround aIn/aOut datasets
+// layer around aIn/aOut datasets.
 func (nn NeuralNetwork) backward(aIn, aOut mat.Matrix) (dB, dW []*mat.Dense) {
 	A, Z := nn.forward(aIn)
 

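Two details of the reworked training path are worth noting. runBatchWorkers splits the data by integer division: with, say, 60000 samples and 7 available threads, chunkSize is 60000/7 = 8571, the workers cover indices 0 through 59996, and the last 3 samples are silently skipped for that epoch. Also, callers of Validate must adapt to the new three-value signature (sketch):

    squareError, failCount, total := network.Validate(trainer)
    log.Printf("error: %v, fails: %v/%v\n", squareError, failCount, total)
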
+ 36 - 0
neuralnetwork/noearlystop.go

@@ -0,0 +1,36 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2020 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+package neuralnetwork
+
+type noEarlyStop struct {
+}
+
+func (es *noEarlyStop) Test() bool {
+	return false
+}
+
+func (es *noEarlyStop) Reset() {
+}