|
@@ -27,7 +27,6 @@ package training
|
|
|
|
|
|
import (
|
|
|
"encoding/binary"
|
|
|
- "fmt"
|
|
|
"io"
|
|
|
"log"
|
|
|
"os"
|
|
@@ -35,12 +34,13 @@ import (
|
|
|
mat "gonum.org/v1/gonum/mat"
|
|
|
)
|
|
|
|
|
|
-type MNISTReader struct {
|
|
|
- file *os.File
|
|
|
- resultsFile *os.File
|
|
|
- fileValidation *os.File
|
|
|
- resultsFileValidation *os.File
|
|
|
- size int
|
|
|
+type mnistReader struct {
|
|
|
+ dataFilename string
|
|
|
+ resultsFilename string
|
|
|
+ validatorFilename string
|
|
|
+ validatorResultsFilename string
|
|
|
+ dataCount int
|
|
|
+ validatorCount int
|
|
|
imageSize int
|
|
|
buffered *mat.Dense
|
|
|
resultsBuffered *mat.Dense
|
|
@@ -48,92 +48,59 @@ type MNISTReader struct {
|
|
|
resultsBufferedValidation *mat.Dense
|
|
|
}
|
|
|
|
|
|
-func NewMNISTReader(dataFilename string, resultsFilename string) (r *MNISTReader) {
|
|
|
- r = &MNISTReader{}
|
|
|
+func NewMNISTReader(dataFilename string, resultsFilename string, validatorFilename string, validatorResultsFilename string) (r *mnistReader) {
|
|
|
+ r = &mnistReader{}
|
|
|
|
|
|
- var err error
|
|
|
- r.file, err = os.Open(dataFilename)
|
|
|
- if err != nil {
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- r.resultsFile, err = os.Open(resultsFilename)
|
|
|
- if err != nil {
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- buffer := make([]byte, 16)
|
|
|
- r.file.Read(buffer)
|
|
|
- header := binary.BigEndian.Uint32(buffer[:4])
|
|
|
- if header != 0x00000803 {
|
|
|
- return nil
|
|
|
- }
|
|
|
- r.size = int(binary.BigEndian.Uint32(buffer[4:8]))
|
|
|
- r.imageSize = int(binary.BigEndian.Uint32(buffer[8:12])) * int(binary.BigEndian.Uint32(buffer[12:16]))
|
|
|
- fmt.Printf("Image size: %v\n", r.imageSize)
|
|
|
- buffer = make([]byte, 8)
|
|
|
- r.resultsFile.Read(buffer)
|
|
|
- header = binary.BigEndian.Uint32(buffer[0:4])
|
|
|
- if header != 0x00000801 {
|
|
|
- return nil
|
|
|
- }
|
|
|
- resultsSize := int(binary.BigEndian.Uint32(buffer[4:8]))
|
|
|
- if resultsSize != r.size {
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- //Separation validation part
|
|
|
- r.fileValidation, err = os.Open(dataFilename)
|
|
|
- if err != nil {
|
|
|
+ r.dataCount, r.imageSize = openFileSet(dataFilename, resultsFilename)
|
|
|
+ r.validatorCount, _ = openFileSet(validatorFilename, validatorResultsFilename)
|
|
|
+ if r.dataCount <= 0 || r.imageSize <= 0 || r.validatorCount <= 0 {
|
|
|
return nil
|
|
|
}
|
|
|
+ return
|
|
|
+}
|
|
|
|
|
|
- r.resultsFileValidation, err = os.Open(resultsFilename)
|
|
|
- if err != nil {
|
|
|
- return nil
|
|
|
+func (r *mnistReader) GetData(i int) (*mat.Dense, *mat.Dense) {
|
|
|
+ if r.dataCount <= i {
|
|
|
+ return nil, nil
|
|
|
}
|
|
|
|
|
|
- r.Reset()
|
|
|
- return
|
|
|
+ return r.readData(r.dataFilename, r.resultsFilename, i)
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) GetData() (*mat.Dense, *mat.Dense) {
|
|
|
- return r.buffered, r.resultsBuffered
|
|
|
+func (r *mnistReader) DataCount() int {
|
|
|
+ return r.dataCount
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) NextData() bool {
|
|
|
- r.buffered, r.resultsBuffered = r.readNextData(r.fileValidation, r.resultsFileValidation)
|
|
|
- if r.buffered != nil && r.resultsBuffered != nil {
|
|
|
- return true
|
|
|
+func (r *mnistReader) GetValidator(i int) (data *mat.Dense, result *mat.Dense) {
|
|
|
+ if r.validatorCount <= i {
|
|
|
+ return nil, nil
|
|
|
}
|
|
|
- r.Reset()
|
|
|
- return false
|
|
|
-}
|
|
|
|
|
|
-func (r *MNISTReader) Reset() {
|
|
|
- r.file.Seek(16, 0)
|
|
|
- r.resultsFile.Seek(8, 0)
|
|
|
-
|
|
|
- r.fileValidation.Seek(16, 0)
|
|
|
- r.resultsFileValidation.Seek(8, 0)
|
|
|
+ return r.readData(r.validatorFilename, r.validatorResultsFilename, i)
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) GetValidator() (*mat.Dense, *mat.Dense) {
|
|
|
- return r.bufferedValidation, r.resultsBufferedValidation
|
|
|
+func (r *mnistReader) ValidatorCount() int {
|
|
|
+ return r.validatorCount
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) NextValidator() bool {
|
|
|
- r.bufferedValidation, r.resultsBufferedValidation = r.readNextData(r.fileValidation, r.resultsFileValidation)
|
|
|
- if r.bufferedValidation != nil && r.resultsBufferedValidation != nil {
|
|
|
- return true
|
|
|
+func (r *mnistReader) readData(data string, result string, i int) (buffered, resultsBuffered *mat.Dense) {
|
|
|
+ file, err := os.Open(data)
|
|
|
+ if err != nil {
|
|
|
+ return nil, nil
|
|
|
}
|
|
|
- r.Reset()
|
|
|
- return false
|
|
|
-}
|
|
|
+ defer file.Close()
|
|
|
+
|
|
|
+ resultsFile, err := os.Open(result)
|
|
|
+ if err != nil {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ defer resultsFile.Close()
|
|
|
+
|
|
|
+ file.Seek(16+int64(r.imageSize*i), 0)
|
|
|
+ resultsFile.Seek(8+int64(i), 0)
|
|
|
|
|
|
-func (r *MNISTReader) readNextData(file *os.File, resultsFile *os.File) (buffered, resultsBuffered *mat.Dense) {
|
|
|
buffer := make([]byte, r.imageSize)
|
|
|
- _, err := file.Read(buffer)
|
|
|
+ _, err = file.Read(buffer)
|
|
|
|
|
|
if err == io.EOF {
|
|
|
return nil, nil
|
|
@@ -162,25 +129,39 @@ func (r *MNISTReader) readNextData(file *os.File, resultsFile *os.File) (buffere
|
|
|
return buffered, resultsBuffered
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) GetDataByIndex(i int) (*mat.Dense, *mat.Dense) {
|
|
|
- file, err := os.Open(r.file.Name())
|
|
|
+func openFileSet(dataFilename string, resultsFilename string) (count int, imageSize int) {
|
|
|
+ var err error
|
|
|
+ data, err := os.Open(dataFilename)
|
|
|
if err != nil {
|
|
|
- return nil, nil
|
|
|
+ return -1, -1
|
|
|
}
|
|
|
- defer file.Close()
|
|
|
+ defer data.Close()
|
|
|
|
|
|
- resultsFile, err := os.Open(r.resultsFile.Name())
|
|
|
+ result, err := os.Open(resultsFilename)
|
|
|
if err != nil {
|
|
|
- return nil, nil
|
|
|
+ return -1, -1
|
|
|
}
|
|
|
- defer resultsFile.Close()
|
|
|
+ defer result.Close()
|
|
|
|
|
|
- file.Seek(16+int64(r.imageSize*i), 0)
|
|
|
- resultsFile.Seek(8+int64(i), 0)
|
|
|
+ buffer := make([]byte, 16)
|
|
|
+ data.Read(buffer)
|
|
|
+ header := binary.BigEndian.Uint32(buffer[:4])
|
|
|
+ if header != 0x00000803 {
|
|
|
+ return -1, -1
|
|
|
+ }
|
|
|
+ count = int(binary.BigEndian.Uint32(buffer[4:8]))
|
|
|
+ imageSize = int(binary.BigEndian.Uint32(buffer[8:12])) * int(binary.BigEndian.Uint32(buffer[12:16]))
|
|
|
|
|
|
- return r.readNextData(file, resultsFile)
|
|
|
-}
|
|
|
+ buffer = make([]byte, 8)
|
|
|
+ result.Read(buffer)
|
|
|
+ header = binary.BigEndian.Uint32(buffer[0:4])
|
|
|
+ if header != 0x00000801 {
|
|
|
+ return -1, -1
|
|
|
+ }
|
|
|
+ resultsCount := int(binary.BigEndian.Uint32(buffer[4:8]))
|
|
|
+ if resultsCount != count {
|
|
|
+ return -1, -1
|
|
|
+ }
|
|
|
|
|
|
-func (r *MNISTReader) GetDataCount() int {
|
|
|
- return r.size
|
|
|
+ return
|
|
|
}
|