Ver código fonte

Add state watcher to output traing progress

Alexey Edelev 4 anos atrás
pai
commit
2d6e93bb81
1 arquivos alterados com 38 adições e 0 exclusões
  1. 38 0
      handwriting/handwriting.go

+ 38 - 0
handwriting/handwriting.go

@@ -50,6 +50,8 @@ func NewHandwritingService() (hws *HandwritingService) {
 		DeltaMax: 50.0,
 		DeltaMin: 0.000001,
 	}))
+
+	hws.nn.SetStateWatcher(hws)
 	return
 }
 
@@ -108,6 +110,42 @@ func (hws *HandwritingService) Run() {
 	}
 }
 
+func (hws *HandwritingService) Init(nn *neuralnetwork.NeuralNetwork) {
+
+}
+
+func (hws *HandwritingService) UpdateState(int) {
+
+}
+
+func (hws *HandwritingService) UpdateActivations(int, *mat.Dense) {
+
+}
+
+func (hws *HandwritingService) UpdateBiases(int, *mat.Dense) {
+
+}
+
+func (hws *HandwritingService) UpdateWeights(int, *mat.Dense) {
+
+}
+
+func (hws *HandwritingService) UpdateTraining(t int, epocs int, samplesProcced int, totalSamplesCount int) {
+	fmt.Printf("Training progress: Epoc: %v/%v\n", t, epocs)
+}
+
+func (hws *HandwritingService) UpdateValidation(validatorCount int, failCount int) {
+
+}
+
+func (hws *HandwritingService) GetSubscriptionFeatures() (features neuralnetwork.SubscriptionFeatures) {
+	features = 0
+	features.Set(neuralnetwork.TrainingSubscription)
+	features.Set(neuralnetwork.ValidationSubscription)
+
+	return
+}
+
 func drawImage(dense *mat.Dense) {
 	for i := 0; i < 28; i++ {
 		for j := 0; j < 28; j++ {