/* * MIT License * * Copyright (c) 2019 Alexey Edelev * * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork * * Permission is hereby granted, free of charge, to any person obtaining a copy of this * software and associated documentation files (the "Software"), to deal in the Software * without restriction, including without limitation the rights to use, copy, modify, * merge, publish, distribute, sublicense, and/or sell copies of the Software, and * to permit persons to whom the Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be included in all copies * or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER * DEALINGS IN THE SOFTWARE. */ package training import ( "encoding/binary" "io" "log" "os" mat "gonum.org/v1/gonum/mat" ) type mnistReader struct { dataFilename string resultsFilename string validatorFilename string validatorResultsFilename string dataCount int validatorCount int imageSize int buffered *mat.Dense resultsBuffered *mat.Dense bufferedValidation *mat.Dense resultsBufferedValidation *mat.Dense } func NewMNISTReader(dataFilename string, resultsFilename string, validatorFilename string, validatorResultsFilename string) (r *mnistReader) { r = &mnistReader{} r.dataCount, r.imageSize = openFileSet(dataFilename, resultsFilename) r.validatorCount, _ = openFileSet(validatorFilename, validatorResultsFilename) if r.dataCount <= 0 || r.imageSize <= 0 || r.validatorCount <= 0 { return nil } return } func (r *mnistReader) GetData(i int) (*mat.Dense, *mat.Dense) { if r.dataCount <= i { return nil, nil } return r.readData(r.dataFilename, r.resultsFilename, i) } func (r *mnistReader) DataCount() int { return r.dataCount } func (r *mnistReader) GetValidator(i int) (data *mat.Dense, result *mat.Dense) { if r.validatorCount <= i { return nil, nil } return r.readData(r.validatorFilename, r.validatorResultsFilename, i) } func (r *mnistReader) ValidatorCount() int { return r.validatorCount } func (r *mnistReader) readData(data string, result string, i int) (buffered, resultsBuffered *mat.Dense) { file, err := os.Open(data) if err != nil { return nil, nil } defer file.Close() resultsFile, err := os.Open(result) if err != nil { return nil, nil } defer resultsFile.Close() file.Seek(16+int64(r.imageSize*i), 0) resultsFile.Seek(8+int64(i), 0) buffer := make([]byte, r.imageSize) _, err = file.Read(buffer) if err == io.EOF { return nil, nil } else if err != nil { log.Fatal("File read error\n") } values := make([]float64, r.imageSize) for i, v := range buffer { values[i] = float64(v) / 255.0 } buffered = mat.NewDense(r.imageSize, 1, values) buffer = make([]byte, 1) _, err = resultsFile.Read(buffer) if err != nil { log.Fatal("Result file read error\n") } num := int(buffer[0]) resultsBuffered = mat.NewDense(10, 1, nil) resultsBuffered.Set(num, 0, 1.0) return buffered, resultsBuffered } func openFileSet(dataFilename string, resultsFilename string) (count int, imageSize int) { var err error data, err := os.Open(dataFilename) if err != nil { return -1, -1 } defer data.Close() result, err := os.Open(resultsFilename) if err != nil { return -1, -1 } defer result.Close() buffer := make([]byte, 16) data.Read(buffer) header := binary.BigEndian.Uint32(buffer[:4]) if header != 0x00000803 { return -1, -1 } count = int(binary.BigEndian.Uint32(buffer[4:8])) imageSize = int(binary.BigEndian.Uint32(buffer[8:12])) * int(binary.BigEndian.Uint32(buffer[12:16])) buffer = make([]byte, 8) result.Read(buffer) header = binary.BigEndian.Uint32(buffer[0:4]) if header != 0x00000801 { return -1, -1 } resultsCount := int(binary.BigEndian.Uint32(buffer[4:8])) if resultsCount != count { return -1, -1 } return }