ソースを参照

Add mnist reader

Alexey Edelev 5 年 前
コミット
9b0a8d7708

+ 2 - 2
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -41,7 +41,7 @@ import (
 //
 // Matrix: A
 // Description: A is the set of calculated neuron activations after sigmoid correction
-// Format:    0            n          N
+// Format:    0          n           N
 //         ⎡A[0] ⎤ ... ⎡A[0] ⎤ ... ⎡A[0] ⎤
 //         ⎢A[1] ⎥ ... ⎢A[1] ⎥ ... ⎢A[1] ⎥
 //         ⎢ ... ⎥ ... ⎢ ... ⎥ ... ⎢ ... ⎥
@@ -52,7 +52,7 @@ import (
 //
 // Matrix: Z
 // Description: Z is the set of calculated raw neuron activations
-// Format:    0            n          N
+// Format:    0          n           N
 //         ⎡Z[0] ⎤ ... ⎡Z[0] ⎤ ... ⎡Z[0] ⎤
 //         ⎢Z[1] ⎥ ... ⎢Z[1] ⎥ ... ⎢Z[1] ⎥
 //         ⎢ ... ⎥ ... ⎢ ... ⎥ ... ⎢ ... ⎥

+ 142 - 0
neuralnetwork/teach/mnistreader.go

@@ -0,0 +1,142 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+package teach
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+	"log"
+	"os"
+
+	mat "gonum.org/v1/gonum/mat"
+)
+
+// MNISTReader steps through an MNIST image file and its companion label file
+// in lockstep, keeping the most recently read sample and its expected output.
+type MNISTReader struct {
+	// file is the open image (data) file.
+	file *os.File
+	// resultsFile is the open label file.
+	resultsFile *os.File
+
+	// size is the number of samples declared in the image file header.
+	size int
+	// imageSize is the number of pixels in one image (rows * columns).
+	imageSize int
+
+	// buffered is the latest image as an imageSize x 1 column, each pixel
+	// scaled from its byte value into the range [0, 1].
+	buffered *mat.Dense
+	// resultsBuffered is the latest label as a 10 x 1 column with 1.0 in the
+	// row of the labelled digit and 0 elsewhere.
+	resultsBuffered *mat.Dense
+}
+
+func NewMNISTReader(dataFilename string, resultsFilename string) (r *MNISTReader) {
+	r = &MNISTReader{}
+
+	var err error
+	r.file, err = os.Open(dataFilename)
+	if err != nil {
+		return nil
+	}
+
+	r.resultsFile, err = os.Open(resultsFilename)
+	if err != nil {
+		return nil
+	}
+
+	buffer := make([]byte, 16)
+	r.file.Read(buffer)
+	header := binary.BigEndian.Uint32(buffer[:4])
+	if header != 0x00000803 {
+		return nil
+	}
+	r.size = int(binary.BigEndian.Uint32(buffer[4:8]))
+	r.imageSize = int(binary.BigEndian.Uint32(buffer[8:12])) * int(binary.BigEndian.Uint32(buffer[12:16]))
+	fmt.Printf("Image size: %v\n", r.imageSize)
+	buffer = make([]byte, 8)
+	r.resultsFile.Read(buffer)
+	header = binary.BigEndian.Uint32(buffer[0:4])
+	if header != 0x00000801 {
+		return nil
+	}
+	resultsSize := int(binary.BigEndian.Uint32(buffer[4:8]))
+	if resultsSize != r.size {
+		return nil
+	}
+
+	return
+}
+
+// GetData returns the image loaded by the most recent successful call to
+// Next as a column matrix of pixel values. It is nil until Next has
+// succeeded at least once.
+func (r *MNISTReader) GetData() *mat.Dense {
+	return r.buffered
+}
+
+// GetExpect returns the expected output for the image loaded by the most
+// recent successful call to Next: a column matrix with 1.0 in the row of the
+// labelled digit. It is nil until Next has succeeded at least once.
+func (r *MNISTReader) GetExpect() *mat.Dense {
+	return r.resultsBuffered
+}
+
+// Next reads the next image and its label and stores them for GetData and
+// GetExpect.
+//
+// It returns true when a sample was loaded. When the image file is exhausted
+// it rewinds both files to the first sample and returns false, so the reader
+// can be iterated again. Any other read failure, including a truncated image
+// or a label outside 0-9, terminates the program via log.Fatal.
+func (r *MNISTReader) Next() bool {
+	pixels := make([]byte, r.imageSize)
+	// io.ReadFull is used instead of a bare Read: Read may return fewer bytes
+	// than requested without an error, which would leave part of the image
+	// silently zeroed.
+	if _, err := io.ReadFull(r.file, pixels); err != nil {
+		if err == io.EOF {
+			// Clean end of data: rewind so the next pass starts at sample 0.
+			r.Reset()
+			return false
+		}
+		log.Fatal("File read error\n")
+	}
+
+	// Scale each pixel byte into [0, 1]. A fresh slice is allocated per call
+	// because the resulting matrix is handed out through GetData.
+	values := make([]float64, len(pixels))
+	for i, p := range pixels {
+		values[i] = float64(p) / 255.0
+	}
+	r.buffered = mat.NewDense(r.imageSize, 1, values)
+
+	label := make([]byte, 1)
+	if _, err := io.ReadFull(r.resultsFile, label); err != nil {
+		log.Fatal("Result file read error\n")
+	}
+
+	// The expected-output column has 10 rows; reject a corrupt label here
+	// rather than letting Set panic with an index-out-of-range.
+	digit := int(label[0])
+	if digit >= 10 {
+		log.Fatalf("Result file contains invalid label %d\n", digit)
+	}
+
+	// One-hot encode the digit.
+	r.resultsBuffered = mat.NewDense(10, 1, nil)
+	r.resultsBuffered.Set(digit, 0, 1.0)
+
+	return true
+}
+
+// Reset rewinds both files to their first sample: just past the 16-byte
+// header of the image file and the 8-byte header of the label file.
+//
+// A failed seek terminates the program via log.Fatalf, matching how Next
+// treats I/O failures. Continuing after rewinding only one of the two files
+// would silently pair images with the wrong labels.
+func (r *MNISTReader) Reset() {
+	if _, err := r.file.Seek(16, io.SeekStart); err != nil {
+		log.Fatalf("File seek error: %v\n", err)
+	}
+	if _, err := r.resultsFile.Seek(8, io.SeekStart); err != nil {
+		log.Fatalf("Result file seek error: %v\n", err)
+	}
+}