/* * MIT License * * Copyright (c) 2019 Alexey Edelev * * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork * * Permission is hereby granted, free of charge, to any person obtaining a copy of this * software and associated documentation files (the "Software"), to deal in the Software * without restriction, including without limitation the rights to use, copy, modify, * merge, publish, distribute, sublicense, and/or sell copies of the Software, and * to permit persons to whom the Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be included in all copies * or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER * DEALINGS IN THE SOFTWARE. */ package training import ( "encoding/binary" "fmt" "io" "log" "os" mat "gonum.org/v1/gonum/mat" ) type MNISTReader struct { file *os.File resultsFile *os.File fileValidation *os.File resultsFileValidation *os.File size int imageSize int buffered *mat.Dense resultsBuffered *mat.Dense bufferedValidation *mat.Dense resultsBufferedValidation *mat.Dense } func NewMNISTReader(dataFilename string, resultsFilename string) (r *MNISTReader) { r = &MNISTReader{} var err error r.file, err = os.Open(dataFilename) if err != nil { return nil } r.resultsFile, err = os.Open(resultsFilename) if err != nil { return nil } buffer := make([]byte, 16) r.file.Read(buffer) header := binary.BigEndian.Uint32(buffer[:4]) if header != 0x00000803 { return nil } r.size = int(binary.BigEndian.Uint32(buffer[4:8])) r.imageSize = int(binary.BigEndian.Uint32(buffer[8:12])) * int(binary.BigEndian.Uint32(buffer[12:16])) fmt.Printf("Image size: %v\n", r.imageSize) buffer = make([]byte, 8) r.resultsFile.Read(buffer) header = binary.BigEndian.Uint32(buffer[0:4]) if header != 0x00000801 { return nil } resultsSize := int(binary.BigEndian.Uint32(buffer[4:8])) if resultsSize != r.size { return nil } //Separation validation part r.fileValidation, err = os.Open(dataFilename) if err != nil { return nil } r.resultsFileValidation, err = os.Open(resultsFilename) if err != nil { return nil } r.Reset() return } func (r *MNISTReader) GetData() (*mat.Dense, *mat.Dense) { return r.buffered, r.resultsBuffered } func (r *MNISTReader) NextData() bool { r.buffered, r.resultsBuffered = r.readNextData(r.fileValidation, r.resultsFileValidation) if r.buffered != nil && r.resultsBuffered != nil { return true } r.Reset() return false } func (r *MNISTReader) Reset() { r.file.Seek(16, 0) r.resultsFile.Seek(8, 0) r.fileValidation.Seek(16, 0) r.resultsFileValidation.Seek(8, 0) } func (r *MNISTReader) GetValidator() (*mat.Dense, *mat.Dense) { return r.bufferedValidation, r.resultsBufferedValidation } func (r *MNISTReader) NextValidator() bool { r.bufferedValidation, r.resultsBufferedValidation = r.readNextData(r.fileValidation, r.resultsFileValidation) if r.bufferedValidation != nil && r.resultsBufferedValidation != nil { return true } r.Reset() return false } func (r *MNISTReader) readNextData(file *os.File, resultsFile *os.File) (buffered, resultsBuffered *mat.Dense) { buffer := make([]byte, r.imageSize) _, err := file.Read(buffer) if err == io.EOF { return nil, nil } else if err != nil { log.Fatal("File read error\n") } values := make([]float64, r.imageSize) for i, v := range buffer { values[i] = float64(v) / 255.0 } buffered = mat.NewDense(r.imageSize, 1, values) buffer = make([]byte, 1) _, err = resultsFile.Read(buffer) if err != nil { log.Fatal("Result file read error\n") } num := int(buffer[0]) resultsBuffered = mat.NewDense(10, 1, nil) resultsBuffered.Set(num, 0, 1.0) return buffered, resultsBuffered } func (r *MNISTReader) GetDataByIndex(i int) (*mat.Dense, *mat.Dense) { file, err := os.Open(r.file.Name()) if err != nil { return nil, nil } defer file.Close() resultsFile, err := os.Open(r.resultsFile.Name()) if err != nil { return nil, nil } defer resultsFile.Close() file.Seek(16+int64(r.imageSize*i), 0) resultsFile.Seek(8+int64(i), 0) return r.readNextData(file, resultsFile) } func (r *MNISTReader) GetDataCount() int { return r.size }