Pārlūkot izejas kodu

Renaming some classes

Alexey Edelev 5 gadi atpakaļ
vecāks
revīzija
5643b42c39

+ 33 - 0
neuralnetwork/genetic/genetic.go

@@ -0,0 +1,33 @@
+package genetic
+
+import (
+	"log"
+
+	neuralnetwork "../neuralnetwork"
+)
+
+type Population struct {
+	Networks []*neuralnetwork.NeuralNetwork
+	verifier PopulationVerifier
+}
+
+func NewPopulation(verifier PopulationVerifier, populationSize int, sizes []int) (p *Population) {
+	p = &Population{
+		verifier: verifier,
+		Networks: make([]*neuralnetwork.NeuralNetwork, populationSize),
+	}
+
+	for i := 0; i < populationSize; i++ {
+		var err error
+		p.Networks[i], err = neuralnetwork.NewNeuralNetwork(sizes, nil)
+		if err != nil {
+			log.Fatal("Could not initialize NeuralNetwork")
+		}
+	}
+
+	return
+}
+
+func (p *Population) NaturalSelection(generationCount int) {
+
+}

+ 9 - 0
neuralnetwork/genetic/interface.go

@@ -0,0 +1,9 @@
+package genetic
+
+import (
+	"gonum.org/v1/gonum/mat"
+)
+
+type PopulationVerifier interface {
+	Verify(Population) *mat.Dense
+}

+ 4 - 4
neuralnetwork/main.go

@@ -44,13 +44,13 @@ func main() {
 	// inFile.Close()
 
 	// failCount = 0
-	// teacher.Reset()
-	// for teacher.NextValidator() {
-	// 	dataSet, expect := teacher.GetValidator()
+	// training.Reset()
+	// for training.NextValidator() {
+	// 	dataSet, expect := training.GetValidator()
 	// 	index, _ := nn.Predict(dataSet)
 	// 	if expect.At(index, 0) != 1.0 {
 	// 		failCount++
-	// 		// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
+	// 		// fmt.Printf("Fail: %v, %v\n\n", training.ValidationIndex(), expect.At(index, 0))
 	// 	}
 	// }
 

+ 4 - 4
neuralnetwork/neuralnetwork/batchworker.go

@@ -26,7 +26,7 @@
 package neuralnetwork
 
 import (
-	teach "../teach"
+	training "../training"
 	mat "gonum.org/v1/gonum/mat"
 )
 
@@ -51,15 +51,15 @@ func newBatchWorker(nn *NeuralNetwork) (bw *batchWorker) {
 	return
 }
 
-func (bw *batchWorker) Run(teacher teach.Teacher, startIndex, endIndex int) {
+func (bw *batchWorker) Run(trainer training.Trainer, startIndex, endIndex int) {
 	for i := startIndex; i < endIndex; i++ {
-		dB, dW := bw.network.backward(teacher.GetDataByIndex(i))
+		dB, dW := bw.network.backward(trainer.GetDataByIndex(i))
 		for l := 1; l < bw.network.layerCount; l++ {
 			bw.BGradient[l].AccumGradients(dB[l])
 			bw.WGradient[l].AccumGradients(dW[l])
 		}
 	}
-	teacher.Reset()
+	trainer.Reset()
 }
 
 func (bw *batchWorker) Result(layer int) (dB, dW *mat.Dense) {

+ 13 - 13
neuralnetwork/neuralnetwork/neuralnetwork.go

@@ -34,7 +34,7 @@ import (
 	"sync"
 	"time"
 
-	teach "../teach"
+	training "../training"
 	mat "gonum.org/v1/gonum/mat"
 )
 
@@ -175,24 +175,24 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	return
 }
 
-func (nn *NeuralNetwork) Teach(teacher teach.Teacher, epocs int) {
+func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 	if nn.watcher != nil {
 		nn.watcher.UpdateState(StateLearning)
 		defer nn.watcher.UpdateState(StateIdle)
 	}
 	if _, ok := nn.WGradient[nn.layerCount-1].(OnlineGradientDescent); ok {
-		nn.TeachOnline(teacher, epocs)
+		nn.TrainOnline(trainer, epocs)
 	} else if _, ok := nn.WGradient[nn.layerCount-1].(BatchGradientDescent); ok {
-		nn.TeachBatch(teacher, epocs)
+		nn.TrainBatch(trainer, epocs)
 	} else {
 		panic("Invalid gradient descent type")
 	}
 }
 
-func (nn *NeuralNetwork) TeachOnline(teacher teach.Teacher, epocs int) {
+func (nn *NeuralNetwork) TrainOnline(trainer training.Trainer, epocs int) {
 	for t := 0; t < epocs; t++ {
-		for teacher.NextData() {
-			dB, dW := nn.backward(teacher.GetData())
+		for trainer.NextData() {
+			dB, dW := nn.backward(trainer.GetData())
 			for l := 1; l < nn.layerCount; l++ {
 				bGradient, ok := nn.BGradient[l].(OnlineGradientDescent)
 				if !ok {
@@ -210,13 +210,13 @@ func (nn *NeuralNetwork) TeachOnline(teacher teach.Teacher, epocs int) {
 				}
 			}
 		}
-		teacher.Reset()
+		trainer.Reset()
 	}
 }
 
-func (nn *NeuralNetwork) TeachBatch(teacher teach.Teacher, epocs int) {
+func (nn *NeuralNetwork) TrainBatch(trainer training.Trainer, epocs int) {
 	for t := 0; t < epocs; t++ {
-		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), teacher)
+		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
 
 		for l := 1; l < nn.layerCount; l++ {
 			bGradient, ok := nn.BGradient[l].(BatchGradientDescent)
@@ -244,16 +244,16 @@ func (nn *NeuralNetwork) TeachBatch(teacher teach.Teacher, epocs int) {
 	}
 }
 
-func (nn *NeuralNetwork) runBatchWorkers(threadCount int, teacher teach.Teacher) (workers []*batchWorker) {
+func (nn *NeuralNetwork) runBatchWorkers(threadCount int, trainer training.Trainer) (workers []*batchWorker) {
 	wg := sync.WaitGroup{}
-	chunkSize := teacher.GetDataCount() / threadCount
+	chunkSize := trainer.GetDataCount() / threadCount
 	workers = make([]*batchWorker, threadCount)
 	for i, _ := range workers {
 		workers[i] = newBatchWorker(nn)
 		wg.Add(1)
 		s := i
 		go func() {
-			workers[s].Run(teacher, s*chunkSize, (s+1)*chunkSize)
+			workers[s].Run(trainer, s*chunkSize, (s+1)*chunkSize)
 			wg.Done()
 		}()
 	}

+ 11 - 11
neuralnetwork/remotecontrol/remotecontrol.go

@@ -41,7 +41,7 @@ import (
 	"gonum.org/v1/gonum/mat"
 	grpc "google.golang.org/grpc"
 
-	teach "../teach"
+	training "../training"
 )
 
 type RemoteControl struct {
@@ -182,9 +182,9 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 	go func() {
 		rw.mutex.Lock()
 		defer rw.mutex.Unlock()
-		// teacher := teach.NewMNISTReader("./minst.data", "./mnist.labels")
-		teacher := teach.NewTextDataReader("wine.data", 5)
-		rw.nn.Teach(teacher, 500)
+		// trainer := training.NewMNISTReader("./minst.data", "./mnist.labels")
+		trainer := training.NewTextDataReader("wine.data", 5)
+		rw.nn.Train(trainer, 500)
 
 		// for i := 0; i < nn.Count; i++ {
 		// 	if i > 0 {
@@ -206,26 +206,26 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 		rw.UpdateState(neuralnetwork.StateLearning)
 		defer rw.UpdateState(neuralnetwork.StateIdle)
 		failCount := 0
-		teacher.Reset()
-		for teacher.NextValidator() {
-			dataSet, expect := teacher.GetValidator()
+		trainer.Reset()
+		for trainer.NextValidator() {
+			dataSet, expect := trainer.GetValidator()
 			index, _ := rw.nn.Predict(dataSet)
 			//TODO: remove this is not used for visualization
 			time.Sleep(400 * time.Millisecond)
 			if expect.At(index, 0) != 1.0 {
 				failCount++
-				// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
+				// fmt.Printf("Fail: %v, %v\n\n", trainer.ValidationIndex(), expect.At(index, 0))
 			}
-			if !teacher.NextValidator() {
+			if !trainer.NextValidator() {
 				fmt.Printf("Fail count: %v\n\n", failCount)
 				failCount = 0
-				teacher.Reset()
+				trainer.Reset()
 			}
 		}
 
 		fmt.Printf("Fail count: %v\n\n", failCount)
 		failCount = 0
-		teacher.Reset()
+		trainer.Reset()
 		rw.UpdateState(neuralnetwork.StateIdle)
 	}()
 

+ 1 - 1
neuralnetwork/teach/mnistreader.go → neuralnetwork/training/mnistreader.go

@@ -23,7 +23,7 @@
  * DEALINGS IN THE SOFTWARE.
  */
 
-package teach
+package training
 
 import (
 	"encoding/binary"

+ 1 - 1
neuralnetwork/teach/textdatareader.go → neuralnetwork/training/textdatareader.go

@@ -23,7 +23,7 @@
  * DEALINGS IN THE SOFTWARE.
  */
 
-package teach
+package training
 
 import (
 	"bufio"

+ 2 - 2
neuralnetwork/teach/teacher.go → neuralnetwork/training/trainer.go

@@ -23,13 +23,13 @@
  * DEALINGS IN THE SOFTWARE.
  */
 
-package teach
+package training
 
 import (
 	mat "gonum.org/v1/gonum/mat"
 )
 
-type Teacher interface {
+type Trainer interface {
 	GetData() (*mat.Dense, *mat.Dense)
 	NextData() bool