@@ -34,7 +34,7 @@ import (
 	"sync"
 	"time"
 
-	teach "../teach"
+	training "../training"
 
 	mat "gonum.org/v1/gonum/mat"
 )
@@ -175,24 +175,24 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	return
 }
 
-func (nn *NeuralNetwork) Teach(teacher teach.Teacher, epocs int) {
+func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 	if nn.watcher != nil {
 		nn.watcher.UpdateState(StateLearning)
 		defer nn.watcher.UpdateState(StateIdle)
 	}
 	if _, ok := nn.WGradient[nn.layerCount-1].(OnlineGradientDescent); ok {
-		nn.TeachOnline(teacher, epocs)
+		nn.TrainOnline(trainer, epocs)
 	} else if _, ok := nn.WGradient[nn.layerCount-1].(BatchGradientDescent); ok {
-		nn.TeachBatch(teacher, epocs)
+		nn.TrainBatch(trainer, epocs)
 	} else {
 		panic("Invalid gradient descent type")
 	}
 }
 
-func (nn *NeuralNetwork) TeachOnline(teacher teach.Teacher, epocs int) {
+func (nn *NeuralNetwork) TrainOnline(trainer training.Trainer, epocs int) {
 	for t := 0; t < epocs; t++ {
-		for teacher.NextData() {
-			dB, dW := nn.backward(teacher.GetData())
+		for trainer.NextData() {
+			dB, dW := nn.backward(trainer.GetData())
 			for l := 1; l < nn.layerCount; l++ {
 				bGradient, ok := nn.BGradient[l].(OnlineGradientDescent)
 				if !ok {
@@ -210,13 +210,13 @@ func (nn *NeuralNetwork) TeachOnline(teacher teach.Teacher, epocs int) {
 				}
 			}
 		}
-		teacher.Reset()
+		trainer.Reset()
 	}
 }
 
-func (nn *NeuralNetwork) TeachBatch(teacher teach.Teacher, epocs int) {
+func (nn *NeuralNetwork) TrainBatch(trainer training.Trainer, epocs int) {
 	for t := 0; t < epocs; t++ {
-		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), teacher)
+		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
 
 		for l := 1; l < nn.layerCount; l++ {
 			bGradient, ok := nn.BGradient[l].(BatchGradientDescent)
@@ -244,16 +244,16 @@ func (nn *NeuralNetwork) TeachBatch(teacher teach.Teacher, epocs int) {
 	}
 }
 
-func (nn *NeuralNetwork) runBatchWorkers(threadCount int, teacher teach.Teacher) (workers []*batchWorker) {
+func (nn *NeuralNetwork) runBatchWorkers(threadCount int, trainer training.Trainer) (workers []*batchWorker) {
 	wg := sync.WaitGroup{}
-	chunkSize := teacher.GetDataCount() / threadCount
+	chunkSize := trainer.GetDataCount() / threadCount
 	workers = make([]*batchWorker, threadCount)
 	for i, _ := range workers {
 		workers[i] = newBatchWorker(nn)
 		wg.Add(1)
 		s := i
 		go func() {
-			workers[s].Run(teacher, s*chunkSize, (s+1)*chunkSize)
+			workers[s].Run(trainer, s*chunkSize, (s+1)*chunkSize)
 			wg.Done()
 		}()
 	}
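
The training.Trainer value threaded through these methods is exercised through four calls in this diff — NextData, GetData, GetDataCount, and Reset — plus being handed to each batchWorker's Run. The interface definition itself lives in the renamed training package and is outside this diff, so the sketch below is reconstructed from those call sites; in particular, GetData returning an (input, expected-output) pair of mat.Matrix values is an assumption inferred from nn.backward(trainer.GetData()), not the package's confirmed definition.

// Minimal sketch of the Trainer interface as implied by the call sites
// above. Assumption: the real definition in the training package may use
// different signatures; mat comes from the gonum import shown in the diff.
package training

import mat "gonum.org/v1/gonum/mat"

type Trainer interface {
	// NextData advances to the next sample, reporting false once the
	// data set is exhausted for the current epoch.
	NextData() bool
	// GetData returns the current input and expected-output vectors,
	// in the shape nn.backward consumes (assumed two-value form).
	GetData() (in, out mat.Matrix)
	// GetDataCount reports the total sample count; runBatchWorkers
	// divides it by the worker count to size each chunk.
	GetDataCount() int
	// Reset rewinds the cursor so the next epoch starts at sample zero.
	Reset()
}

One consequence of the chunking in runBatchWorkers as written: chunkSize is computed with integer division, so when GetDataCount() is not a multiple of threadCount, the trailing GetDataCount() % threadCount samples fall outside every worker's range (assuming Run treats its start/end bounds as half-open) and are skipped for that epoch.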