|
@@ -1,9 +1,6 @@
|
|
|
package neuralnetworkbase
|
|
|
|
|
|
import (
|
|
|
- rand "math/rand"
|
|
|
- "time"
|
|
|
-
|
|
|
mat "gonum.org/v1/gonum/mat"
|
|
|
)
|
|
|
|
|
@@ -56,10 +53,7 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
|
|
|
}
|
|
|
|
|
|
func (nn *NeuralNetwork) Train(dataSet, expect []*mat.Dense) {
|
|
|
- rand.Seed(time.Now().UnixNano())
|
|
|
dataSetSize := len(dataSet)
|
|
|
-
|
|
|
-
|
|
|
for i := 0; i < nn.trainingCycles; i++ {
|
|
|
for j := dataSetSize - 1; j >= 0; j -= 3 {
|
|
|
if j < 0 {
|
|
@@ -67,10 +61,6 @@ func (nn *NeuralNetwork) Train(dataSet, expect []*mat.Dense) {
|
|
|
}
|
|
|
nn.Backward(dataSet[j], expect[j])
|
|
|
}
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
}
|
|
|
}
|
|
|
|