|
@@ -36,25 +36,20 @@ import (
|
|
|
)
|
|
|
|
|
|
type MNISTReader struct {
|
|
|
- file *os.File
|
|
|
- resultsFile *os.File
|
|
|
- size int
|
|
|
- imageSize int
|
|
|
- buffered *mat.Dense
|
|
|
- resultsBuffered *mat.Dense
|
|
|
- window MNISTBatchWindow
|
|
|
- currentIndex int64
|
|
|
+ file *os.File
|
|
|
+ resultsFile *os.File
|
|
|
+ fileValidation *os.File
|
|
|
+ resultsFileValidation *os.File
|
|
|
+ size int
|
|
|
+ imageSize int
|
|
|
+ buffered *mat.Dense
|
|
|
+ resultsBuffered *mat.Dense
|
|
|
+ bufferedValidation *mat.Dense
|
|
|
+ resultsBufferedValidation *mat.Dense
|
|
|
}
|
|
|
|
|
|
-type MNISTBatchWindow struct {
|
|
|
- from int64
|
|
|
- to int64
|
|
|
-}
|
|
|
-
|
|
|
-func NewMNISTReader(dataFilename string, resultsFilename string, window MNISTBatchWindow) (r *MNISTReader) {
|
|
|
- r = &MNISTReader{
|
|
|
- window: window,
|
|
|
- }
|
|
|
+func NewMNISTReader(dataFilename string, resultsFilename string) (r *MNISTReader) {
|
|
|
+ r = &MNISTReader{}
|
|
|
|
|
|
var err error
|
|
|
r.file, err = os.Open(dataFilename)
|
|
@@ -87,25 +82,61 @@ func NewMNISTReader(dataFilename string, resultsFilename string, window MNISTBat
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+ //Separation validation part
|
|
|
+ r.fileValidation, err = os.Open(dataFilename)
|
|
|
+ if err != nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ r.resultsFileValidation, err = os.Open(resultsFilename)
|
|
|
+ if err != nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
r.Reset()
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) GetData() *mat.Dense {
|
|
|
- return r.buffered
|
|
|
+func (r *MNISTReader) GetData() (*mat.Dense, *mat.Dense) {
|
|
|
+ return r.buffered, r.resultsBuffered
|
|
|
+}
|
|
|
+
|
|
|
+func (r *MNISTReader) NextData() bool {
|
|
|
+ r.buffered, r.resultsBuffered = r.readNextData(r.fileValidation, r.resultsFileValidation)
|
|
|
+ if r.buffered != nil && r.resultsBuffered != nil {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ r.Reset()
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func (r *MNISTReader) Reset() {
|
|
|
+ r.file.Seek(16, 0)
|
|
|
+ r.resultsFile.Seek(8, 0)
|
|
|
+
|
|
|
+ r.fileValidation.Seek(16, 0)
|
|
|
+ r.resultsFileValidation.Seek(8, 0)
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) GetExpect() *mat.Dense {
|
|
|
- return r.resultsBuffered
|
|
|
+func (r *MNISTReader) GetValidator() (*mat.Dense, *mat.Dense) {
|
|
|
+ return r.bufferedValidation, r.resultsBufferedValidation
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) Next() bool {
|
|
|
+func (r *MNISTReader) NextValidator() bool {
|
|
|
+ r.bufferedValidation, r.resultsBufferedValidation = r.readNextData(r.fileValidation, r.resultsFileValidation)
|
|
|
+ if r.bufferedValidation != nil && r.resultsBufferedValidation != nil {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ r.Reset()
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func (r *MNISTReader) readNextData(file *os.File, resultsFile *os.File) (buffered, resultsBuffered *mat.Dense) {
|
|
|
buffer := make([]byte, r.imageSize)
|
|
|
- _, err := r.file.Read(buffer)
|
|
|
+ _, err := file.Read(buffer)
|
|
|
|
|
|
- if err == io.EOF || r.currentIndex >= r.window.to {
|
|
|
- r.Reset()
|
|
|
- return false
|
|
|
+ if err == io.EOF {
|
|
|
+ return nil, nil
|
|
|
} else if err != nil {
|
|
|
log.Fatal("File read error\n")
|
|
|
}
|
|
@@ -115,38 +146,41 @@ func (r *MNISTReader) Next() bool {
|
|
|
values[i] = float64(v) / 255.0
|
|
|
}
|
|
|
|
|
|
- r.buffered = mat.NewDense(r.imageSize, 1, values)
|
|
|
-
|
|
|
- // values = make([]float64, len(values))
|
|
|
- // for i, v := range buffer {
|
|
|
- // if v > 0 {
|
|
|
- // values[i] = 1
|
|
|
- // } else {
|
|
|
- // values[i] = 0
|
|
|
- // }
|
|
|
- // }
|
|
|
-
|
|
|
- // squareDense := mat.NewDense(28, 28, values)
|
|
|
- // fmt.Printf("r.buffered:\n%v\n\n", mat.Formatted(squareDense, mat.Prefix(""), mat.Excerpt(0), mat.Squeeze()))
|
|
|
+ buffered = mat.NewDense(r.imageSize, 1, values)
|
|
|
|
|
|
buffer = make([]byte, 1)
|
|
|
- _, err = r.resultsFile.Read(buffer)
|
|
|
+ _, err = resultsFile.Read(buffer)
|
|
|
if err != nil {
|
|
|
log.Fatal("Result file read error\n")
|
|
|
}
|
|
|
|
|
|
num := int(buffer[0])
|
|
|
|
|
|
- r.resultsBuffered = mat.NewDense(10, 1, nil)
|
|
|
- r.resultsBuffered.Set(num, 0, 1.0)
|
|
|
+ resultsBuffered = mat.NewDense(10, 1, nil)
|
|
|
+ resultsBuffered.Set(num, 0, 1.0)
|
|
|
|
|
|
- // fmt.Printf("r.resultsBuffered:\n%v\n\n", mat.Formatted(r.resultsBuffered, mat.Prefix(""), mat.Excerpt(0)))
|
|
|
- r.currentIndex++
|
|
|
+ return buffered, resultsBuffered
|
|
|
+}
|
|
|
|
|
|
- return true
|
|
|
+func (r *MNISTReader) GetDataByIndex(i int) (*mat.Dense, *mat.Dense) {
|
|
|
+ file, err := os.Open(r.file.Name())
|
|
|
+ if err != nil {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ defer file.Close()
|
|
|
+
|
|
|
+ resultsFile, err := os.Open(r.resultsFile.Name())
|
|
|
+ if err != nil {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ defer resultsFile.Close()
|
|
|
+
|
|
|
+ file.Seek(16+int64(r.imageSize*i), 0)
|
|
|
+ resultsFile.Seek(8+int64(i), 0)
|
|
|
+
|
|
|
+ return r.readNextData(file, resultsFile)
|
|
|
}
|
|
|
|
|
|
-func (r *MNISTReader) Reset() {
|
|
|
- r.file.Seek(16+r.window.from*int64(r.imageSize), 0)
|
|
|
- r.resultsFile.Seek(8+r.window.from*int64(r.imageSize), 0)
|
|
|
+func (r *MNISTReader) GetDataCount() int {
|
|
|
+ return r.size
|
|
|
}
|