Tatyana Borisova пре 5 година
родитељ
комит
5564b2158f
2 измењених фајлова са 129 додато и 0 уклоњено
  1. 2 0
      neuralnetwork/main.go
  2. 127 0
      neuralnetwork/neuralnetworkbase/neuralnetwork.go

+ 2 - 0
neuralnetwork/main.go

@@ -32,6 +32,8 @@ func main() {
 	// 	fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
 	// }
 
+	//nn.SaveState("./data");
+	//nn.LoadState("./data");
 	failCount := 0
 	teacher.Reset()
 	for teacher.NextData() {

+ 127 - 0
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -26,8 +26,11 @@
 package neuralnetworkbase
 
 import (
+	"encoding/binary"
 	"errors"
 	"fmt"
+	"log"
+	"os"
 
 	teach "../teach"
 	mat "gonum.org/v1/gonum/mat"
@@ -195,12 +198,136 @@ func (nn *NeuralNetwork) Teach(teacher teach.Teacher) {
 	}
 }
 
+func check(e error) {
+	if e != nil {
+		panic(e)
+	}
+}
+
 func (nn *NeuralNetwork) SaveState(filename string) {
+	// Open file for reding
+	inputFile, err := os.Create(filename)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	defer inputFile.Close()
+
+	//save input array count
+	bufferSize := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bufferSize[0:], uint32(nn.Count))
+	n2, err := inputFile.Write(bufferSize)
+
+	check(err)
+	fmt.Printf("wrote value %d\n", uint32(nn.Count))
+
+	// save an input array
+	buffer := make([]byte, nn.Count*4)
+	for i := 0; i < nn.Count; i++ {
+		binary.LittleEndian.PutUint32(buffer[i*4:], uint32(nn.Sizes[i]))
+	}
+
+	n2, err = inputFile.Write(buffer)
+	check(err)
+	fmt.Printf("wrote buffer %d bytes\n", n2)
+
+	//save biases
+	////////////////////////
+	for i := 1; i < nn.Count; i++ {
+		saveDense(inputFile, nn.Biases[i])
+	}
+
+	//save weights
+	////////////////////////
+	for i := 1; i < nn.Count; i++ {
+		saveDense(inputFile, nn.Weights[i])
+	}
+}
+
+func saveDense(inputFile *os.File, matrix *mat.Dense) {
+	buffer, _ := matrix.MarshalBinary()
+	//save int size of Biases buffer
+	bufferSize := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bufferSize, uint32(len(buffer)))
+	inputFile.Write(bufferSize)
+	bufferCount, err := inputFile.Write(buffer)
+	check(err)
+	fmt.Printf("wrote array size %d count of bytes %d \n", len(buffer), bufferCount)
+	printMatDense(matrix)
+}
+
+func printMatDense(matrix *mat.Dense) {
+	// Print the result using the formatter.
+	fc := mat.Formatted(matrix, mat.Prefix("    "), mat.Squeeze())
+	fmt.Printf("c = %v \n\n", fc)
+}
+
+func readDense(inputFile *os.File, matrix *mat.Dense) *mat.Dense {
+	count := readInt(inputFile)
+	fmt.Printf("%d \n\n", count)
+	matrix = &mat.Dense{}
+	matrix.UnmarshalBinary(readByteArray(inputFile, count))
+	printMatDense(matrix)
+	return matrix
+}
+
+func readByteArray(inputFile *os.File, size int) []byte {
+	// Read an input array
+	sizeBuffer := make([]byte, size)
+	n1, err := inputFile.Read(sizeBuffer)
+	check(err)
+
+	fmt.Printf("readByteArray: size = %d \n", n1)
 
+	return sizeBuffer
+}
+
+func readInt(inputFile *os.File) int {
+	// Reade int
+	count := make([]byte, 4)
+	_, err := inputFile.Read(count)
+	check(err)
+
+	return int(binary.LittleEndian.Uint32(count))
 }
 
 func (nn *NeuralNetwork) LoadState(filename string) {
+	inputFile, err := os.Open(filename)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	defer inputFile.Close()
 
+	// Reade count
+	nn.Count = readInt(inputFile)
+
+	// Read an input array
+	sizeBuffer := readByteArray(inputFile, nn.Count*4)
+	nn.Sizes = make([]int, nn.Count)
+
+	for i := 0; i < nn.Count; i++ {
+		nn.Sizes[i] = int(binary.LittleEndian.Uint32(sizeBuffer[i*4:]))
+		fmt.Printf("LoadState: nn.Sizes[%d] %d \n", i, nn.Sizes[i])
+	}
+
+	nn.Weights = []*mat.Dense{&mat.Dense{}}
+	nn.Biases = []*mat.Dense{&mat.Dense{}}
+
+	// read Biases
+	nn.Biases[0] = &mat.Dense{}
+	for i := 1; i < nn.Count; i++ {
+		nn.Biases = append(nn.Biases, &mat.Dense{})
+		nn.Biases[i] = readDense(inputFile, nn.Biases[i])
+	}
+
+	// read Weights
+	nn.Weights[0] = &mat.Dense{}
+	for i := 1; i < nn.Count; i++ {
+		nn.Weights = append(nn.Weights, &mat.Dense{})
+		nn.Weights[i] = readDense(inputFile, nn.Weights[i])
+	}
+	fmt.Printf("\nLoadState end\n")
 }
 
 func (nn *NeuralNetwork) forward(aIn mat.Matrix) {