ソースを参照

Add teacher interface

- Add teacher interface
- Refine data reader to match teacher interface
Alexey Edelev 5 年 前
コミット
fb3a074f14

+ 0 - 72
neuralnetwork/datareader.go

@@ -1,72 +0,0 @@
-package main
-
-import (
-	"bufio"
-	"fmt"
-	"log"
-	"os"
-	"strconv"
-	"strings"
-
-	mat "gonum.org/v1/gonum/mat"
-)
-
-func readData(filename string) (dataSet, result []*mat.Dense) {
-	inputFile, err := os.Open(filename)
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	defer inputFile.Close()
-
-	scanner := bufio.NewScanner(inputFile)
-	scanner.Split(bufio.ScanLines)
-
-	max := []float64{0.0, 0.0, 0.0, 0.0}
-	for scanner.Scan() {
-		dataLine := scanner.Text()
-		data := strings.Split(dataLine, ",")
-
-		if len(data) < 5 {
-			fmt.Printf("Garbage record: %s\n", dataLine)
-			continue
-		}
-
-		var dataRaw []float64
-		for i := 0; i < 4; i++ {
-			val, err := strconv.ParseFloat(data[i], 64)
-			if err != nil {
-				break
-			}
-			dataRaw = append(dataRaw, val)
-
-			if max[i] < val {
-				max[i] = val
-			}
-		}
-
-		if len(dataRaw) < 4 {
-			fmt.Printf("Garbage record: %s\n", dataLine)
-			continue
-		}
-		dataSet = append(dataSet, mat.NewDense(4, 1, dataRaw))
-
-		switch data[4] {
-		case "Iris-setosa":
-			result = append(result, mat.NewDense(3, 1, []float64{1.0, 0.0, 0.0}))
-		case "Iris-versicolor":
-			result = append(result, mat.NewDense(3, 1, []float64{0.0, 1.0, 0.0}))
-		case "Iris-virginica":
-			result = append(result, mat.NewDense(3, 1, []float64{0.0, 0.0, 1.0}))
-		}
-	}
-
-	//normalize
-	for i := 0; i < len(dataSet); i++ {
-		dataSet[i].Apply(func(r, _ int, val float64) float64 {
-			return val / max[r]
-		}, dataSet[i])
-	}
-
-	return
-}

+ 9 - 6
neuralnetwork/main.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	neuralnetwork "./neuralnetworkbase"
+	teach "./teach"
 )
 
 func main() {
@@ -19,8 +20,8 @@ func main() {
 	// 	fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
 	// }
 
-	dataSet, result := readData("./iris.data")
-	nn.Train(dataSet, result)
+	teacher := teach.NewTextDataReader("./iris.data")
+	nn.Teach(teacher)
 
 	// for i := 0; i < nn.Count; i++ {
 	// 	if i > 0 {
@@ -32,11 +33,13 @@ func main() {
 	// }
 
 	failCount := 0
-	for i := 0; i < len(dataSet); i++ {
-		index, _ := nn.Predict(dataSet[i])
-		if result[i].At(index, 0) != 1.0 {
+	teacher.Reset()
+	for teacher.Next() {
+		index, _ := nn.Predict(teacher.GetData())
+		expect := teacher.GetExpect()
+		if expect.At(index, 0) != 1.0 {
 			failCount++
-			fmt.Printf("Fail: %v, %v\n\n", i, result[i].At(index, 0))
+			fmt.Printf("Fail: %v, %v\n\n", teacher.Index(), expect.At(index, 0))
 		}
 	}
 

+ 1 - 1
neuralnetwork/neuralnetworkbase/common.go

@@ -1,7 +1,7 @@
 package neuralnetworkbase
 
 import (
-	math "math"
+	"math"
 	rand "math/rand"
 	"time"
 

+ 4 - 7
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -1,6 +1,7 @@
 package neuralnetworkbase
 
 import (
+	teach "../teach"
 	mat "gonum.org/v1/gonum/mat"
 )
 
@@ -48,14 +49,10 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	return
 }
 
-func (nn *NeuralNetwork) Train(dataSet, expect []*mat.Dense) {
-	dataSetSize := len(dataSet)
+func (nn *NeuralNetwork) Teach(teacher teach.Teacher) {
 	for i := 0; i < nn.trainingCycles; i++ {
-		for j := dataSetSize - 1; j >= 0; j -= 3 {
-			if j < 0 {
-				j = 0
-			}
-			nn.backward(dataSet[j], expect[j])
+		for teacher.Next() {
+			nn.backward(teacher.GetData(), teacher.GetExpect())
 		}
 	}
 }

+ 12 - 0
neuralnetwork/teach/teacher.go

@@ -0,0 +1,12 @@
+package teach
+
+import (
+	mat "gonum.org/v1/gonum/mat"
+)
+
+type Teacher interface {
+	GetData() *mat.Dense
+	GetExpect() *mat.Dense
+	Next() bool
+	Reset()
+}

+ 110 - 0
neuralnetwork/teach/textdatareader.go

@@ -0,0 +1,110 @@
+package teach
+
+import (
+	"bufio"
+	"fmt"
+	"log"
+	"os"
+	"strconv"
+	"strings"
+
+	mat "gonum.org/v1/gonum/mat"
+)
+
+type TextDataReader struct {
+	dataSet []*mat.Dense
+	result  []*mat.Dense
+	index   int
+}
+
+func NewTextDataReader(filename string) *TextDataReader {
+	r := &TextDataReader{
+		index: 0,
+	}
+	r.readData(filename)
+
+	return r
+}
+func (r *TextDataReader) readData(filename string) {
+	inputFile, err := os.Open(filename)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	defer inputFile.Close()
+
+	scanner := bufio.NewScanner(inputFile)
+	scanner.Split(bufio.ScanLines)
+
+	max := []float64{0.0, 0.0, 0.0, 0.0}
+	for scanner.Scan() {
+		dataLine := scanner.Text()
+		data := strings.Split(dataLine, ",")
+
+		if len(data) < 5 {
+			fmt.Printf("Garbage record: %s\n", dataLine)
+			continue
+		}
+
+		var dataRaw []float64
+		for i := 0; i < 4; i++ {
+			val, err := strconv.ParseFloat(data[i], 64)
+			if err != nil {
+				break
+			}
+			dataRaw = append(dataRaw, val)
+
+			if max[i] < val {
+				max[i] = val
+			}
+		}
+
+		if len(dataRaw) < 4 {
+			fmt.Printf("Garbage record: %s\n", dataLine)
+			continue
+		}
+		r.dataSet = append(r.dataSet, mat.NewDense(4, 1, dataRaw))
+
+		switch data[4] {
+		case "Iris-setosa":
+			r.result = append(r.result, mat.NewDense(3, 1, []float64{1.0, 0.0, 0.0}))
+		case "Iris-versicolor":
+			r.result = append(r.result, mat.NewDense(3, 1, []float64{0.0, 1.0, 0.0}))
+		case "Iris-virginica":
+			r.result = append(r.result, mat.NewDense(3, 1, []float64{0.0, 0.0, 1.0}))
+		}
+	}
+
+	//normalize
+	for i := 0; i < len(r.dataSet); i++ {
+		r.dataSet[i].Apply(func(r, _ int, val float64) float64 {
+			return val / max[r]
+		}, r.dataSet[i])
+	}
+}
+
+func (r *TextDataReader) GetData() *mat.Dense {
+	return r.dataSet[r.index]
+}
+
+func (r *TextDataReader) GetExpect() *mat.Dense {
+	return r.result[r.index]
+}
+
+func (r *TextDataReader) Next() bool {
+	r.index++
+	if r.index >= len(r.result) {
+		r.index = 0
+		return false
+	}
+
+	return true
+}
+
+func (r *TextDataReader) Reset() {
+	r.index = 0
+}
+
+func (r *TextDataReader) Index() int {
+	return r.index
+}