Browse Source

Fix issue with random initialization

Alexey Edelev 5 years ago
parent
commit
c6b4c4d651

+ 2 - 0
neuralnetwork/neuralnetworkbase/common.go

@@ -3,11 +3,13 @@ package neuralnetworkbase
 import (
 	math "math"
 	rand "math/rand"
+	"time"
 
 	mat "gonum.org/v1/gonum/mat"
 )
 
 func generateRandomDense(rows, columns int) *mat.Dense {
+	rand.Seed(time.Now().UnixNano())
 	data := make([]float64, rows*columns)
 	for i := range data {
 		data[i] = rand.NormFloat64()

+ 0 - 10
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -1,9 +1,6 @@
 package neuralnetworkbase
 
 import (
-	rand "math/rand"
-	"time"
-
 	mat "gonum.org/v1/gonum/mat"
 )
 
@@ -56,10 +53,7 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 }
 
 func (nn *NeuralNetwork) Train(dataSet, expect []*mat.Dense) {
-	rand.Seed(time.Now().UnixNano())
 	dataSetSize := len(dataSet)
-	// randomIndex := rand.Int() % dataSetSize
-	// fmt.Printf("Train: %v\n", randomIndex)
 	for i := 0; i < nn.trainingCycles; i++ {
 		for j := dataSetSize - 1; j >= 0; j -= 3 {
 			if j < 0 {
@@ -67,10 +61,6 @@ func (nn *NeuralNetwork) Train(dataSet, expect []*mat.Dense) {
 			}
 			nn.Backward(dataSet[j], expect[j])
 		}
-		// _, max := nn.Predict(dataSet[randomIndex])
-		// if 1.0-max < 0.2 {
-		// 	break
-		// }
 	}
 }