|
@@ -221,13 +221,15 @@ func (nn *NeuralNetwork) Reset(sizes []int) (err error) {
|
|
}
|
|
}
|
|
|
|
|
|
// SetStateWatcher setups state watcher for NeuralNetwork. StateWatcher is common
|
|
// SetStateWatcher setups state watcher for NeuralNetwork. StateWatcher is common
|
|
-// interface that collects data about NeuralNetwork behaivor. If not specified (is
|
|
|
|
|
|
+// interface that collects data about NeuralNetwork behavior. If not specified (is
|
|
// set to nil) NeuralNetwork will ignore StateWatcher interations
|
|
// set to nil) NeuralNetwork will ignore StateWatcher interations
|
|
func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
|
|
func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
|
|
nn.watcher = watcher
|
|
nn.watcher = watcher
|
|
if watcher != nil {
|
|
if watcher != nil {
|
|
watcher.Init(nn)
|
|
watcher.Init(nn)
|
|
- watcher.UpdateState(StateIdle)
|
|
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
|
|
|
|
+ watcher.UpdateState(StateIdle)
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -237,8 +239,10 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
|
|
nn.syncMutex.Lock()
|
|
nn.syncMutex.Lock()
|
|
defer nn.syncMutex.Unlock()
|
|
defer nn.syncMutex.Unlock()
|
|
if nn.watcher != nil {
|
|
if nn.watcher != nil {
|
|
- nn.watcher.UpdateState(StatePredict)
|
|
|
|
- defer nn.watcher.UpdateState(StateIdle)
|
|
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
|
|
|
|
+ nn.watcher.UpdateState(StatePredict)
|
|
|
|
+ defer nn.watcher.UpdateState(StateIdle)
|
|
|
|
+ }
|
|
}
|
|
}
|
|
r, _ := aIn.Dims()
|
|
r, _ := aIn.Dims()
|
|
if r != nn.Sizes[0] {
|
|
if r != nn.Sizes[0] {
|
|
@@ -267,14 +271,18 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
|
|
nn.syncMutex.Lock()
|
|
nn.syncMutex.Lock()
|
|
defer nn.syncMutex.Unlock()
|
|
defer nn.syncMutex.Unlock()
|
|
failCount = 0
|
|
failCount = 0
|
|
- total = 0
|
|
|
|
|
|
+ total = trainer.ValidatorCount()
|
|
for i := 0; i < trainer.ValidatorCount(); i++ {
|
|
for i := 0; i < trainer.ValidatorCount(); i++ {
|
|
dataSet, expect := trainer.GetValidator(i)
|
|
dataSet, expect := trainer.GetValidator(i)
|
|
index, _ := nn.Predict(dataSet)
|
|
index, _ := nn.Predict(dataSet)
|
|
if expect.At(index, 0) != 1.0 {
|
|
if expect.At(index, 0) != 1.0 {
|
|
failCount++
|
|
failCount++
|
|
}
|
|
}
|
|
- total++
|
|
|
|
|
|
+ }
|
|
|
|
+ if nn.watcher != nil {
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(ValidationSubscription) {
|
|
|
|
+ nn.watcher.UpdateValidation(total, failCount)
|
|
|
|
+ }
|
|
}
|
|
}
|
|
return
|
|
return
|
|
}
|
|
}
|
|
@@ -284,8 +292,10 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
|
|
// to get training data. Training loops are limited buy number of epocs
|
|
// to get training data. Training loops are limited buy number of epocs
|
|
func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
|
|
func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
|
|
if nn.watcher != nil {
|
|
if nn.watcher != nil {
|
|
- nn.watcher.UpdateState(StateLearning)
|
|
|
|
- defer nn.watcher.UpdateState(StateIdle)
|
|
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
|
|
|
|
+ nn.watcher.UpdateState(StateLearning)
|
|
|
|
+ defer nn.watcher.UpdateState(StateIdle)
|
|
|
|
+ }
|
|
}
|
|
}
|
|
if _, ok := nn.WGradient[nn.LayerCount-1].(OnlineGradientDescent); ok {
|
|
if _, ok := nn.WGradient[nn.LayerCount-1].(OnlineGradientDescent); ok {
|
|
nn.trainOnline(trainer, epocs)
|
|
nn.trainOnline(trainer, epocs)
|
|
@@ -299,6 +309,11 @@ func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
|
|
func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
|
|
func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
|
|
for t := 0; t < epocs; t++ {
|
|
for t := 0; t < epocs; t++ {
|
|
for i := 0; i < trainer.DataCount(); i++ {
|
|
for i := 0; i < trainer.DataCount(); i++ {
|
|
|
|
+ if nn.watcher != nil {
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(TrainingSubscription) {
|
|
|
|
+ nn.watcher.UpdateTraining(t, epocs, i, trainer.DataCount())
|
|
|
|
+ }
|
|
|
|
+ }
|
|
nn.syncMutex.Lock()
|
|
nn.syncMutex.Lock()
|
|
dB, dW := nn.backward(trainer.GetData(i))
|
|
dB, dW := nn.backward(trainer.GetData(i))
|
|
for l := 1; l < nn.LayerCount; l++ {
|
|
for l := 1; l < nn.LayerCount; l++ {
|
|
@@ -313,8 +328,12 @@ func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
|
|
nn.Biases[l] = bGradient.ApplyDelta(nn.Biases[l], dB[l])
|
|
nn.Biases[l] = bGradient.ApplyDelta(nn.Biases[l], dB[l])
|
|
nn.Weights[l] = wGradient.ApplyDelta(nn.Weights[l], dW[l])
|
|
nn.Weights[l] = wGradient.ApplyDelta(nn.Weights[l], dW[l])
|
|
if nn.watcher != nil {
|
|
if nn.watcher != nil {
|
|
- nn.watcher.UpdateBiases(l, nn.Biases[l])
|
|
|
|
- nn.watcher.UpdateWeights(l, nn.Weights[l])
|
|
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(BiasesSubscription) {
|
|
|
|
+ nn.watcher.UpdateBiases(l, mat.DenseCopyOf(nn.Biases[l]))
|
|
|
|
+ }
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(WeightsSubscription) {
|
|
|
|
+ nn.watcher.UpdateWeights(l, mat.DenseCopyOf(nn.Weights[l]))
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
nn.syncMutex.Unlock()
|
|
nn.syncMutex.Unlock()
|
|
@@ -325,6 +344,11 @@ func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
|
|
func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
|
|
func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
|
|
fmt.Printf("Start training in %v threads\n", runtime.NumCPU())
|
|
fmt.Printf("Start training in %v threads\n", runtime.NumCPU())
|
|
for t := 0; t < epocs; t++ {
|
|
for t := 0; t < epocs; t++ {
|
|
|
|
+ if nn.watcher != nil {
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(TrainingSubscription) {
|
|
|
|
+ nn.watcher.UpdateTraining(t, epocs, 0, trainer.DataCount())
|
|
|
|
+ }
|
|
|
|
+ }
|
|
batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
|
|
batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
|
|
nn.syncMutex.Lock()
|
|
nn.syncMutex.Lock()
|
|
for l := 1; l < nn.LayerCount; l++ {
|
|
for l := 1; l < nn.LayerCount; l++ {
|
|
@@ -344,8 +368,12 @@ func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
|
|
nn.Biases[l] = bGradient.ApplyDelta(nn.Biases[l])
|
|
nn.Biases[l] = bGradient.ApplyDelta(nn.Biases[l])
|
|
nn.Weights[l] = wGradient.ApplyDelta(nn.Weights[l])
|
|
nn.Weights[l] = wGradient.ApplyDelta(nn.Weights[l])
|
|
if nn.watcher != nil {
|
|
if nn.watcher != nil {
|
|
- nn.watcher.UpdateBiases(l, nn.Biases[l])
|
|
|
|
- nn.watcher.UpdateWeights(l, nn.Weights[l])
|
|
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(BiasesSubscription) {
|
|
|
|
+ nn.watcher.UpdateBiases(l, mat.DenseCopyOf(nn.Biases[l]))
|
|
|
|
+ }
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(WeightsSubscription) {
|
|
|
|
+ nn.watcher.UpdateWeights(l, mat.DenseCopyOf(nn.Weights[l]))
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
nn.syncMutex.Unlock()
|
|
nn.syncMutex.Unlock()
|
|
@@ -472,7 +500,9 @@ func (nn NeuralNetwork) forward(aIn mat.Matrix) (A, Z []*mat.Dense) {
|
|
A[0] = mat.DenseCopyOf(aIn)
|
|
A[0] = mat.DenseCopyOf(aIn)
|
|
|
|
|
|
if nn.watcher != nil {
|
|
if nn.watcher != nil {
|
|
- nn.watcher.UpdateActivations(0, A[0])
|
|
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(ActivationsSubscription) {
|
|
|
|
+ nn.watcher.UpdateActivations(0, mat.DenseCopyOf(A[0]))
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
for l := 1; l < nn.LayerCount; l++ {
|
|
for l := 1; l < nn.LayerCount; l++ {
|
|
@@ -495,7 +525,9 @@ func (nn NeuralNetwork) forward(aIn mat.Matrix) (A, Z []*mat.Dense) {
|
|
// σ(W[l]*A[l−1]+B[l])
|
|
// σ(W[l]*A[l−1]+B[l])
|
|
aDst.Apply(applySigmoid, aDst)
|
|
aDst.Apply(applySigmoid, aDst)
|
|
if nn.watcher != nil {
|
|
if nn.watcher != nil {
|
|
- nn.watcher.UpdateActivations(l, aDst)
|
|
|
|
|
|
+ if nn.watcher.GetSubscriptionFeatures().Has(ActivationsSubscription) {
|
|
|
|
+ nn.watcher.UpdateActivations(l, mat.DenseCopyOf(aDst))
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return
|
|
return
|