
Extend StateWatcher interface

- Add subscription features logic
- Add updates reporting training and validation progress
Alexey Edelev, 5 years ago
parent
commit
bb2fbfe748
2 changed files with 77 additions and 14 deletions
  1. neuralnetwork/interface.go (+31, -0)
  2. neuralnetwork/neuralnetwork.go (+46, -14)

+ 31 - 0
neuralnetwork/interface.go

@@ -36,10 +36,41 @@ const (
 	StatePredict    = 4
 )
 
+type SubscriptionFeatures uint8
+
+const (
+	ActivationsSubscription SubscriptionFeatures = 1 << iota
+	BiasesSubscription
+	WeightsSubscription
+	TrainingSubscription
+	ValidationSubscription
+	StateSubscription
+	AllSubscription = 0xFF
+)
+
 type StateWatcher interface {
 	Init(nn *NeuralNetwork)
 	UpdateState(state int)
 	UpdateActivations(l int, a *mat.Dense)
 	UpdateBiases(l int, biases *mat.Dense)
 	UpdateWeights(l int, weights *mat.Dense)
+	UpdateTraining(t int, epocs int, samplesProcced int, totalSamplesCount int)
+	UpdateValidation(validatorCount int, failCount int)
+	GetSubscriptionFeatures() SubscriptionFeatures
+}
+
+func (f SubscriptionFeatures) Has(flag SubscriptionFeatures) bool {
+	return f&flag != 0
+}
+
+func (f *SubscriptionFeatures) Set(flag SubscriptionFeatures) {
+	*f |= flag
+}
+
+func (f *SubscriptionFeatures) Unset(flag SubscriptionFeatures) {
+	*f &= (^flag)
+}
+
+func (f *SubscriptionFeatures) Clear() {
+	*f = 0
 }
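
The flag helpers above form a plain uint8 bitmask. Below is a minimal, self-contained sketch of how the flags compose; the type, constants and methods mirror the diff above, while the main function is illustrative only.

package main

import "fmt"

type SubscriptionFeatures uint8

const (
	ActivationsSubscription SubscriptionFeatures = 1 << iota
	BiasesSubscription
	WeightsSubscription
	TrainingSubscription
	ValidationSubscription
	StateSubscription
	AllSubscription = 0xFF
)

// Has reports whether any bit of flag is set in f.
func (f SubscriptionFeatures) Has(flag SubscriptionFeatures) bool { return f&flag != 0 }

// Set and Unset turn the given bits on and off; Clear drops every subscription.
func (f *SubscriptionFeatures) Set(flag SubscriptionFeatures)   { *f |= flag }
func (f *SubscriptionFeatures) Unset(flag SubscriptionFeatures) { *f &= ^flag }
func (f *SubscriptionFeatures) Clear()                          { *f = 0 }

func main() {
	var features SubscriptionFeatures
	// Flags OR together, so a watcher can subscribe to several feature groups at once.
	features.Set(TrainingSubscription | ValidationSubscription)
	fmt.Println(features.Has(TrainingSubscription))   // true
	fmt.Println(features.Has(WeightsSubscription))    // false
	features.Unset(ValidationSubscription)
	fmt.Println(features.Has(ValidationSubscription)) // false
}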

+ 46 - 14
neuralnetwork/neuralnetwork.go

@@ -221,13 +221,15 @@ func (nn *NeuralNetwork) Reset(sizes []int) (err error) {
 }
 
 // SetStateWatcher sets up a state watcher for NeuralNetwork. StateWatcher is a common
-// interface that collects data about NeuralNetwork behaivor. If not specified (is
+// interface that collects data about NeuralNetwork behavior. If not specified (is
 // set to nil) NeuralNetwork will ignore StateWatcher interactions
 func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
 	nn.watcher = watcher
 	if watcher != nil {
 		watcher.Init(nn)
-		watcher.UpdateState(StateIdle)
+		if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
+			watcher.UpdateState(StateIdle)
+		}
 	}
 }
 
@@ -237,8 +239,10 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
 	if nn.watcher != nil {
-		nn.watcher.UpdateState(StatePredict)
-		defer nn.watcher.UpdateState(StateIdle)
+		if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
+			nn.watcher.UpdateState(StatePredict)
+			defer nn.watcher.UpdateState(StateIdle)
+		}
 	}
 	r, _ := aIn.Dims()
 	if r != nn.Sizes[0] {
@@ -267,14 +271,18 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
 	failCount = 0
-	total = 0
+	total = trainer.ValidatorCount()
 	for i := 0; i < trainer.ValidatorCount(); i++ {
 		dataSet, expect := trainer.GetValidator(i)
 		index, _ := nn.Predict(dataSet)
 		if expect.At(index, 0) != 1.0 {
 			failCount++
 		}
-		total++
+	}
+	if nn.watcher != nil {
+		if nn.watcher.GetSubscriptionFeatures().Has(ValidationSubscription) {
+			nn.watcher.UpdateValidation(total, failCount)
+		}
 	}
 	return
 }
@@ -284,8 +292,10 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
 // to get training data. Training loops are limited by the number of epocs
 func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 	if nn.watcher != nil {
-		nn.watcher.UpdateState(StateLearning)
-		defer nn.watcher.UpdateState(StateIdle)
+		if nn.watcher.GetSubscriptionFeatures().Has(StateSubscription) {
+			nn.watcher.UpdateState(StateLearning)
+			defer nn.watcher.UpdateState(StateIdle)
+		}
 	}
 	if _, ok := nn.WGradient[nn.LayerCount-1].(OnlineGradientDescent); ok {
 		nn.trainOnline(trainer, epocs)
@@ -299,6 +309,11 @@ func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
 	for t := 0; t < epocs; t++ {
 		for i := 0; i < trainer.DataCount(); i++ {
+			if nn.watcher != nil {
+				if nn.watcher.GetSubscriptionFeatures().Has(TrainingSubscription) {
+					nn.watcher.UpdateTraining(t, epocs, i, trainer.DataCount())
+				}
+			}
 			nn.syncMutex.Lock()
 			dB, dW := nn.backward(trainer.GetData(i))
 			for l := 1; l < nn.LayerCount; l++ {
@@ -313,8 +328,12 @@ func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
 				nn.Biases[l] = bGradient.ApplyDelta(nn.Biases[l], dB[l])
 				nn.Weights[l] = wGradient.ApplyDelta(nn.Weights[l], dW[l])
 				if nn.watcher != nil {
-					nn.watcher.UpdateBiases(l, nn.Biases[l])
-					nn.watcher.UpdateWeights(l, nn.Weights[l])
+					if nn.watcher.GetSubscriptionFeatures().Has(BiasesSubscription) {
+						nn.watcher.UpdateBiases(l, mat.DenseCopyOf(nn.Biases[l]))
+					}
+					if nn.watcher.GetSubscriptionFeatures().Has(WeightsSubscription) {
+						nn.watcher.UpdateWeights(l, mat.DenseCopyOf(nn.Weights[l]))
+					}
 				}
 			}
 			nn.syncMutex.Unlock()
@@ -325,6 +344,11 @@ func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
 func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
 	fmt.Printf("Start training in %v threads\n", runtime.NumCPU())
 	for t := 0; t < epocs; t++ {
+		if nn.watcher != nil {
+			if nn.watcher.GetSubscriptionFeatures().Has(TrainingSubscription) {
+				nn.watcher.UpdateTraining(t, epocs, 0, trainer.DataCount())
+			}
+		}
 		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
 		nn.syncMutex.Lock()
 		for l := 1; l < nn.LayerCount; l++ {
@@ -344,8 +368,12 @@ func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
 			nn.Biases[l] = bGradient.ApplyDelta(nn.Biases[l])
 			nn.Weights[l] = wGradient.ApplyDelta(nn.Weights[l])
 			if nn.watcher != nil {
-				nn.watcher.UpdateBiases(l, nn.Biases[l])
-				nn.watcher.UpdateWeights(l, nn.Weights[l])
+				if nn.watcher.GetSubscriptionFeatures().Has(BiasesSubscription) {
+					nn.watcher.UpdateBiases(l, mat.DenseCopyOf(nn.Biases[l]))
+				}
+				if nn.watcher.GetSubscriptionFeatures().Has(WeightsSubscription) {
+					nn.watcher.UpdateWeights(l, mat.DenseCopyOf(nn.Weights[l]))
+				}
 			}
 		}
 		nn.syncMutex.Unlock()
@@ -472,7 +500,9 @@ func (nn NeuralNetwork) forward(aIn mat.Matrix) (A, Z []*mat.Dense) {
 	A[0] = mat.DenseCopyOf(aIn)
 
 	if nn.watcher != nil {
-		nn.watcher.UpdateActivations(0, A[0])
+		if nn.watcher.GetSubscriptionFeatures().Has(ActivationsSubscription) {
+			nn.watcher.UpdateActivations(0, mat.DenseCopyOf(A[0]))
+		}
 	}
 
 	for l := 1; l < nn.LayerCount; l++ {
@@ -495,7 +525,9 @@ func (nn NeuralNetwork) forward(aIn mat.Matrix) (A, Z []*mat.Dense) {
 		// σ(W[l]*A[l−1]+B[l])
 		aDst.Apply(applySigmoid, aDst)
 		if nn.watcher != nil {
-			nn.watcher.UpdateActivations(l, aDst)
+			if nn.watcher.GetSubscriptionFeatures().Has(ActivationsSubscription) {
+				nn.watcher.UpdateActivations(l, mat.DenseCopyOf(aDst))
+			}
 		}
 	}
 	return
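
Taken together, the guards above mean a watcher only pays for the data it subscribes to: activations, biases and weights are copied with mat.DenseCopyOf only when the corresponding bit is set, and the unsubscribed Update callbacks are skipped entirely. As a hypothetical example (not part of this commit), a watcher living in the same neuralnetwork package that only wants training progress and validation results could look like this:

package neuralnetwork

import (
	"fmt"

	"gonum.org/v1/gonum/mat"
)

// ProgressWatcher subscribes only to training and validation updates; the
// unsubscribed Update callbacks are never invoked for it, so they can stay
// empty (Init is still always called from SetStateWatcher).
type ProgressWatcher struct{}

func (w *ProgressWatcher) Init(nn *NeuralNetwork)                  {}
func (w *ProgressWatcher) UpdateState(state int)                   {}
func (w *ProgressWatcher) UpdateActivations(l int, a *mat.Dense)   {}
func (w *ProgressWatcher) UpdateBiases(l int, biases *mat.Dense)   {}
func (w *ProgressWatcher) UpdateWeights(l int, weights *mat.Dense) {}

func (w *ProgressWatcher) UpdateTraining(t int, epocs int, samplesProcced int, totalSamplesCount int) {
	fmt.Printf("epoch %d/%d: sample %d of %d\n", t+1, epocs, samplesProcced, totalSamplesCount)
}

func (w *ProgressWatcher) UpdateValidation(validatorCount int, failCount int) {
	fmt.Printf("validation: %d of %d failed\n", failCount, validatorCount)
}

func (w *ProgressWatcher) GetSubscriptionFeatures() SubscriptionFeatures {
	var f SubscriptionFeatures
	f.Set(TrainingSubscription | ValidationSubscription)
	return f
}

Such a watcher would be attached with nn.SetStateWatcher(&ProgressWatcher{}); because StateSubscription is not set, the UpdateState calls in SetStateWatcher, Predict and Train are skipped for it as well.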