|
@@ -30,6 +30,7 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "os"
|
|
|
"runtime"
|
|
|
"sync"
|
|
|
"time"
|
|
@@ -107,7 +108,7 @@ type NeuralNetwork struct {
|
|
|
WGradient []interface{}
|
|
|
gradientDescentInitializer GradientDescentInitializer
|
|
|
watcher StateWatcher
|
|
|
- syncMutex sync.Mutex
|
|
|
+ syncMutex *sync.Mutex
|
|
|
}
|
|
|
|
|
|
func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentInitializer) (nn *NeuralNetwork, err error) {
|
|
@@ -133,6 +134,7 @@ func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentIni
|
|
|
BGradient: make([]interface{}, lenSizes),
|
|
|
WGradient: make([]interface{}, lenSizes),
|
|
|
gradientDescentInitializer: gradientDescentInitializer,
|
|
|
+ syncMutex: &sync.Mutex{},
|
|
|
}
|
|
|
|
|
|
for l := 1; l < nn.LayerCount; l++ {
|
|
@@ -156,6 +158,7 @@ func (nn *NeuralNetwork) Copy() (outNN *NeuralNetwork) {
|
|
|
WGradient: make([]interface{}, nn.LayerCount),
|
|
|
gradientDescentInitializer: nn.gradientDescentInitializer,
|
|
|
watcher: nn.watcher,
|
|
|
+ syncMutex: &sync.Mutex{},
|
|
|
}
|
|
|
for l := 1; l < outNN.LayerCount; l++ {
|
|
|
outNN.Biases[l] = mat.DenseCopyOf(nn.Biases[l])
|
|
@@ -203,6 +206,22 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total int) {
|
|
|
+ failCount = 0
|
|
|
+ total = 0
|
|
|
+ trainer.Reset()
|
|
|
+ for trainer.NextValidator() {
|
|
|
+ dataSet, expect := trainer.GetValidator()
|
|
|
+ index, _ := nn.Predict(dataSet)
|
|
|
+ if expect.At(index, 0) != 1.0 {
|
|
|
+ failCount++
|
|
|
+ }
|
|
|
+ total++
|
|
|
+ }
|
|
|
+ trainer.Reset()
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
|
|
|
if nn.watcher != nil {
|
|
|
nn.watcher.UpdateState(StateLearning)
|
|
@@ -325,6 +344,13 @@ func (nn *NeuralNetwork) SaveState(writer io.Writer) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func (nn *NeuralNetwork) SaveStateToFile(filePath string) {
|
|
|
+ outFile, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
|
|
|
+ check(err)
|
|
|
+ defer outFile.Close()
|
|
|
+ nn.SaveState(outFile)
|
|
|
+}
|
|
|
+
|
|
|
func (nn *NeuralNetwork) LoadState(reader io.Reader) {
|
|
|
// Reade count
|
|
|
nn.LayerCount = readInt(reader)
|
|
@@ -364,6 +390,13 @@ func (nn *NeuralNetwork) LoadState(reader io.Reader) {
|
|
|
// fmt.Printf("\nLoadState end\n")
|
|
|
}
|
|
|
|
|
|
+func (nn *NeuralNetwork) LoadStateFromFile(filePath string) {
|
|
|
+ inFile, err := os.Open(filePath)
|
|
|
+ check(err)
|
|
|
+ defer inFile.Close()
|
|
|
+ nn.LoadState(inFile)
|
|
|
+}
|
|
|
+
|
|
|
func (nn NeuralNetwork) forward(aIn mat.Matrix) (A, Z []*mat.Dense) {
|
|
|
A = make([]*mat.Dense, nn.LayerCount)
|
|
|
Z = make([]*mat.Dense, nn.LayerCount)
|