|
@@ -26,8 +26,11 @@
|
|
|
package neuralnetworkbase
|
|
|
|
|
|
import (
|
|
|
+ "encoding/binary"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "log"
|
|
|
+ "os"
|
|
|
|
|
|
teach "../teach"
|
|
|
mat "gonum.org/v1/gonum/mat"
|
|
@@ -195,12 +198,136 @@ func (nn *NeuralNetwork) Teach(teacher teach.Teacher) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func check(e error) {
|
|
|
+ if e != nil {
|
|
|
+ panic(e)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (nn *NeuralNetwork) SaveState(filename string) {
|
|
|
+ // Open file for reding
|
|
|
+ inputFile, err := os.Create(filename)
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ defer inputFile.Close()
|
|
|
+
|
|
|
+ //save input array count
|
|
|
+ bufferSize := make([]byte, 4)
|
|
|
+ binary.LittleEndian.PutUint32(bufferSize[0:], uint32(nn.Count))
|
|
|
+ n2, err := inputFile.Write(bufferSize)
|
|
|
+
|
|
|
+ check(err)
|
|
|
+ fmt.Printf("wrote value %d\n", uint32(nn.Count))
|
|
|
+
|
|
|
+ // save an input array
|
|
|
+ buffer := make([]byte, nn.Count*4)
|
|
|
+ for i := 0; i < nn.Count; i++ {
|
|
|
+ binary.LittleEndian.PutUint32(buffer[i*4:], uint32(nn.Sizes[i]))
|
|
|
+ }
|
|
|
+
|
|
|
+ n2, err = inputFile.Write(buffer)
|
|
|
+ check(err)
|
|
|
+ fmt.Printf("wrote buffer %d bytes\n", n2)
|
|
|
+
|
|
|
+ //save biases
|
|
|
+ ////////////////////////
|
|
|
+ for i := 1; i < nn.Count; i++ {
|
|
|
+ saveDense(inputFile, nn.Biases[i])
|
|
|
+ }
|
|
|
+
|
|
|
+ //save weights
|
|
|
+ ////////////////////////
|
|
|
+ for i := 1; i < nn.Count; i++ {
|
|
|
+ saveDense(inputFile, nn.Weights[i])
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func saveDense(inputFile *os.File, matrix *mat.Dense) {
|
|
|
+ buffer, _ := matrix.MarshalBinary()
|
|
|
+ //save int size of Biases buffer
|
|
|
+ bufferSize := make([]byte, 4)
|
|
|
+ binary.LittleEndian.PutUint32(bufferSize, uint32(len(buffer)))
|
|
|
+ inputFile.Write(bufferSize)
|
|
|
+ bufferCount, err := inputFile.Write(buffer)
|
|
|
+ check(err)
|
|
|
+ fmt.Printf("wrote array size %d count of bytes %d \n", len(buffer), bufferCount)
|
|
|
+ printMatDense(matrix)
|
|
|
+}
|
|
|
+
|
|
|
+func printMatDense(matrix *mat.Dense) {
|
|
|
+ // Print the result using the formatter.
|
|
|
+ fc := mat.Formatted(matrix, mat.Prefix(" "), mat.Squeeze())
|
|
|
+ fmt.Printf("c = %v \n\n", fc)
|
|
|
+}
|
|
|
+
|
|
|
+func readDense(inputFile *os.File, matrix *mat.Dense) *mat.Dense {
|
|
|
+ count := readInt(inputFile)
|
|
|
+ fmt.Printf("%d \n\n", count)
|
|
|
+ matrix = &mat.Dense{}
|
|
|
+ matrix.UnmarshalBinary(readByteArray(inputFile, count))
|
|
|
+ printMatDense(matrix)
|
|
|
+ return matrix
|
|
|
+}
|
|
|
+
|
|
|
+func readByteArray(inputFile *os.File, size int) []byte {
|
|
|
+ // Read an input array
|
|
|
+ sizeBuffer := make([]byte, size)
|
|
|
+ n1, err := inputFile.Read(sizeBuffer)
|
|
|
+ check(err)
|
|
|
+
|
|
|
+ fmt.Printf("readByteArray: size = %d \n", n1)
|
|
|
|
|
|
+ return sizeBuffer
|
|
|
+}
|
|
|
+
|
|
|
+func readInt(inputFile *os.File) int {
|
|
|
+ // Reade int
|
|
|
+ count := make([]byte, 4)
|
|
|
+ _, err := inputFile.Read(count)
|
|
|
+ check(err)
|
|
|
+
|
|
|
+ return int(binary.LittleEndian.Uint32(count))
|
|
|
}
|
|
|
|
|
|
func (nn *NeuralNetwork) LoadState(filename string) {
|
|
|
+ inputFile, err := os.Open(filename)
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ defer inputFile.Close()
|
|
|
|
|
|
+ // Reade count
|
|
|
+ nn.Count = readInt(inputFile)
|
|
|
+
|
|
|
+ // Read an input array
|
|
|
+ sizeBuffer := readByteArray(inputFile, nn.Count*4)
|
|
|
+ nn.Sizes = make([]int, nn.Count)
|
|
|
+
|
|
|
+ for i := 0; i < nn.Count; i++ {
|
|
|
+ nn.Sizes[i] = int(binary.LittleEndian.Uint32(sizeBuffer[i*4:]))
|
|
|
+ fmt.Printf("LoadState: nn.Sizes[%d] %d \n", i, nn.Sizes[i])
|
|
|
+ }
|
|
|
+
|
|
|
+ nn.Weights = []*mat.Dense{&mat.Dense{}}
|
|
|
+ nn.Biases = []*mat.Dense{&mat.Dense{}}
|
|
|
+
|
|
|
+ // read Biases
|
|
|
+ nn.Biases[0] = &mat.Dense{}
|
|
|
+ for i := 1; i < nn.Count; i++ {
|
|
|
+ nn.Biases = append(nn.Biases, &mat.Dense{})
|
|
|
+ nn.Biases[i] = readDense(inputFile, nn.Biases[i])
|
|
|
+ }
|
|
|
+
|
|
|
+ // read Weights
|
|
|
+ nn.Weights[0] = &mat.Dense{}
|
|
|
+ for i := 1; i < nn.Count; i++ {
|
|
|
+ nn.Weights = append(nn.Weights, &mat.Dense{})
|
|
|
+ nn.Weights[i] = readDense(inputFile, nn.Weights[i])
|
|
|
+ }
|
|
|
+ fmt.Printf("\nLoadState end\n")
|
|
|
}
|
|
|
|
|
|
func (nn *NeuralNetwork) forward(aIn mat.Matrix) {
|