Forráskód Böngészése

Add verification
Add specialized file load and save methods for neural networks
Minor code cleanup

Alexey Edelev 5 éve
szülő
commit
c390b17db0

+ 7 - 21
handwriting/handwriting/handwriting.go

@@ -29,9 +29,7 @@ import (
 	"bytes"
 	context "context"
 	fmt "fmt"
-	"log"
 	"net"
-	"os"
 
 	neuralnetwork "../../neuralnetwork/neuralnetwork"
 	training "../../neuralnetwork/training"
@@ -81,29 +79,17 @@ func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) {
 	fmt.Println("ReTrain")
 
 	trainer := training.NewMNISTReader("./mnist.data", "./mnist.labels")
-	failCount := 0
-	total := 0
-	trainer.Reset()
-	for trainer.NextValidator() {
-		total++
-		dataSet, expect := trainer.GetValidator()
-		index, _ := hws.nn.Predict(dataSet)
-		if expect.At(index, 0) != 1.0 {
-			failCount++
-			// fmt.Printf("Fail: %v, %v\n\n", trainer.ValidationIndex(), expect.At(index, 0))
-		}
-	}
-	fmt.Printf("Fail count: %v/%v\n\n", failCount, total)
+	failCount, total := hws.nn.Validate(trainer)
+	fmt.Printf("Fail count before: %v/%v\n\n", failCount, total)
 
 	hws.nn.Train(trainer, 100)
 
+	hws.nn.SaveStateToFile("./mnistnet.nnd")
+
+	failCount, total = hws.nn.Validate(trainer)
+	fmt.Printf("Fail count after: %v/%v\n\n", failCount, total)
+
 	fmt.Println("ReTrain finished")
-	outFile, err := os.OpenFile("./mnistnet.nnd", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
-	if err != nil {
-		log.Fatal(err)
-	}
-	defer outFile.Close()
-	hws.nn.SaveState(outFile)
 	return &None{}, nil
 }
 

+ 47 - 0
neuralnetwork/neuralnetwork/gradient.go

@@ -0,0 +1,47 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+package neuralnetwork
+
+import (
+	mat "gonum.org/v1/gonum/mat"
+)
+
+const (
+	BiasGradient   = iota
+	WeightGradient = iota
+)
+
+type GradientDescentInitializer func(nn *NeuralNetwork, layer, gradientType int) interface{}
+
+type OnlineGradientDescent interface {
+	ApplyDelta(m mat.Matrix, gradient mat.Matrix) *mat.Dense
+}
+
+type BatchGradientDescent interface {
+	ApplyDelta(m mat.Matrix) *mat.Dense
+	AccumGradients(gradient mat.Matrix)
+	Gradients() *mat.Dense
+}

+ 0 - 17
neuralnetwork/neuralnetwork/interface.go

@@ -29,23 +29,6 @@ import (
 	mat "gonum.org/v1/gonum/mat"
 )
 
-const (
-	BiasGradient   = iota
-	WeightGradient = iota
-)
-
-type GradientDescentInitializer func(nn *NeuralNetwork, layer, gradientType int) interface{}
-
-type OnlineGradientDescent interface {
-	ApplyDelta(m mat.Matrix, gradient mat.Matrix) *mat.Dense
-}
-
-type BatchGradientDescent interface {
-	ApplyDelta(m mat.Matrix) *mat.Dense
-	AccumGradients(gradient mat.Matrix)
-	Gradients() *mat.Dense
-}
-
 const (
 	StateIdle       = 1
 	StateLearning   = 2

+ 34 - 1
neuralnetwork/neuralnetwork/neuralnetwork.go

@@ -30,6 +30,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"os"
 	"runtime"
 	"sync"
 	"time"
@@ -107,7 +108,7 @@ type NeuralNetwork struct {
 	WGradient                  []interface{}
 	gradientDescentInitializer GradientDescentInitializer
 	watcher                    StateWatcher
-	syncMutex                  sync.Mutex
+	syncMutex                  *sync.Mutex
 }
 
 func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentInitializer) (nn *NeuralNetwork, err error) {
@@ -133,6 +134,7 @@ func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentIni
 		BGradient:                  make([]interface{}, lenSizes),
 		WGradient:                  make([]interface{}, lenSizes),
 		gradientDescentInitializer: gradientDescentInitializer,
+		syncMutex:                  &sync.Mutex{},
 	}
 
 	for l := 1; l < nn.LayerCount; l++ {
@@ -156,6 +158,7 @@ func (nn *NeuralNetwork) Copy() (outNN *NeuralNetwork) {
 		WGradient:                  make([]interface{}, nn.LayerCount),
 		gradientDescentInitializer: nn.gradientDescentInitializer,
 		watcher:                    nn.watcher,
+		syncMutex:                  &sync.Mutex{},
 	}
 	for l := 1; l < outNN.LayerCount; l++ {
 		outNN.Biases[l] = mat.DenseCopyOf(nn.Biases[l])
@@ -203,6 +206,22 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	return
 }
 
+func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total int) {
+	failCount = 0
+	total = 0
+	trainer.Reset()
+	for trainer.NextValidator() {
+		dataSet, expect := trainer.GetValidator()
+		index, _ := nn.Predict(dataSet)
+		if expect.At(index, 0) != 1.0 {
+			failCount++
+		}
+		total++
+	}
+	trainer.Reset()
+	return
+}
+
 func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 	if nn.watcher != nil {
 		nn.watcher.UpdateState(StateLearning)
@@ -325,6 +344,13 @@ func (nn *NeuralNetwork) SaveState(writer io.Writer) {
 	}
 }
 
+func (nn *NeuralNetwork) SaveStateToFile(filePath string) {
+	outFile, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+	check(err)
+	defer outFile.Close()
+	nn.SaveState(outFile)
+}
+
 func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 	// Read count
 	nn.LayerCount = readInt(reader)
@@ -364,6 +390,13 @@ func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 	// fmt.Printf("\nLoadState end\n")
 }
 
+func (nn *NeuralNetwork) LoadStateFromFile(filePath string) {
+	inFile, err := os.Open(filePath)
+	check(err)
+	defer inFile.Close()
+	nn.LoadState(inFile)
+}
+
 func (nn NeuralNetwork) forward(aIn mat.Matrix) (A, Z []*mat.Dense) {
 	A = make([]*mat.Dense, nn.LayerCount)
 	Z = make([]*mat.Dense, nn.LayerCount)

+ 2 - 9
neuralnetwork/remotecontrol/remotecontrol.go

@@ -30,7 +30,6 @@ import (
 	fmt "fmt"
 	"log"
 	"net"
-	"os"
 	"sync"
 	"time"
 
@@ -199,13 +198,7 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 		// 	fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
 		// }
 
-		outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
-		if err != nil {
-			log.Fatal(err)
-		}
-		defer outFile.Close()
-		rw.nn.SaveState(outFile)
-		outFile.Close()
+		rw.nn.SaveStateToFile("./neuralnetworkdata.nnd")
 
 		rw.UpdateState(neuralnetwork.StateLearning)
 		defer rw.UpdateState(neuralnetwork.StateIdle)
@@ -214,7 +207,7 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 		for trainer.NextValidator() {
 			dataSet, expect := trainer.GetValidator()
 			index, _ := rw.nn.Predict(dataSet)
-			//TODO: remove this is not used for visualization
+			//TODO: remove this if not used for visualization
 			time.Sleep(400 * time.Millisecond)
 			if expect.At(index, 0) != 1.0 {
 				failCount++