Bladeren bron

Update Training

Alexey Edelev 5 jaren geleden
bovenliggende
commit
9fb7ebb94e
3 gewijzigde bestanden met toevoegingen van 44 en 21 verwijderingen
  1. 0 1
      iris.data
  2. 3 11
      neuralnetwork/main.go
  3. 41 9
      neuralnetwork/neuralnetworkbase/neuralnetwork.go

+ 0 - 1
iris.data

@@ -148,4 +148,3 @@
 6.5,3.0,5.2,2.0,Iris-virginica
 6.2,3.4,5.4,2.3,Iris-virginica
 5.9,3.0,5.1,1.8,Iris-virginica
-

+ 3 - 11
neuralnetwork/main.go

@@ -8,11 +8,8 @@ import (
 )
 
 func main() {
-
-	dataSet, result := readData("./iris.data")
-
 	sizes := []int{4, 8, 8, 3}
-	nn := neuralnetwork.NewNeuralNetwork(sizes)
+	nn := neuralnetwork.NewNeuralNetwork(sizes, 0.1, 481)
 
 	for i := 0; i < nn.Count; i++ {
 		if i > 0 {
@@ -23,13 +20,8 @@ func main() {
 		fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
 	}
 
-	for j := 0; j < 481; j++ {
-		for i := len(dataSet) - 1; i >= 0; i-- {
-			// 	fmt.Printf("Dataset[%d]:\n%v\n\n", i, mat.Formatted(dataSet[i], mat.Prefix(""), mat.Excerpt(0)))
-			// 	fmt.Printf("Result[%d]:\n%v\n\n", i, mat.Formatted(result[i], mat.Prefix(""), mat.Excerpt(0)))
-			nn.Backward(dataSet[i], result[i])
-		}
-	}
+	dataSet, result := readData("./iris.data")
+	nn.Train(dataSet, result)
 
 	for i := 0; i < nn.Count; i++ {
 		if i > 0 {

+ 41 - 9
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -1,24 +1,28 @@
 package neuralnetworkbase
 
 import (
+	rand "math/rand"
+	"time"
+
 	mat "gonum.org/v1/gonum/mat"
 )
 
 type NeuralNetwork struct {
-	Count   int
-	Sizes   []int
-	Biases  []*mat.Dense
-	Weights []*mat.Dense
-	A       []*mat.Dense
-	Z       []*mat.Dense
-	alpha   float64
+	Count          int
+	Sizes          []int
+	Biases         []*mat.Dense
+	Weights        []*mat.Dense
+	A              []*mat.Dense
+	Z              []*mat.Dense
+	alpha          float64
+	trainingCycles int
 }
 
 func (nn *NeuralNetwork) Result() *mat.Dense {
 	return nn.A[nn.Count-1]
 }
 
-func NewNeuralNetwork(Sizes []int) (nn *NeuralNetwork) {
+func NewNeuralNetwork(Sizes []int, nu float64, trainingCycles int) (nn *NeuralNetwork) {
 	nn = &NeuralNetwork{}
 	nn.Sizes = Sizes
 	nn.Count = len(Sizes)
@@ -26,7 +30,8 @@ func NewNeuralNetwork(Sizes []int) (nn *NeuralNetwork) {
 	nn.Biases = make([]*mat.Dense, nn.Count)
 	nn.A = make([]*mat.Dense, nn.Count)
 	nn.Z = make([]*mat.Dense, nn.Count)
-	nn.alpha = 0.1 / float64(nn.Sizes[0])
+	nn.alpha = nu / float64(nn.Sizes[0])
+	nn.trainingCycles = trainingCycles
 
 	for i := 1; i < nn.Count; i++ {
 		nn.Weights[i] = generateRandomDense(nn.Sizes[i], nn.Sizes[i-1])
@@ -50,6 +55,33 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	return
 }
 
+func (nn *NeuralNetwork) Train(dataSet, expect []*mat.Dense) {
+	rand.Seed(time.Now().UnixNano())
+	dataSetSize := len(dataSet)
+	// randomIndex := rand.Int() % dataSetSize
+	// fmt.Printf("Train: %v\n", randomIndex)
+	for i := 0; i < nn.trainingCycles; i++ {
+		for j := dataSetSize - 1; j >= 0; j -= 3 {
+			if j < 0 {
+				j = 0
+			}
+			nn.Backward(dataSet[j], expect[j])
+		}
+		// _, max := nn.Predict(dataSet[randomIndex])
+		// if 1.0-max < 0.2 {
+		// 	break
+		// }
+	}
+}
+
+func (nn *NeuralNetwork) SaveState(filename string) {
+
+}
+
+func (nn *NeuralNetwork) LoadState(filename string) {
+
+}
+
 func (nn *NeuralNetwork) Forward(aIn mat.Matrix) {
 	nn.A[0] = mat.DenseCopyOf(aIn)