فهرست منبع

Implement multi-threaded MNIST reader

Alexey Edelev 5 سال پیش
والد
کامیت
59a583d90a

+ 5 - 5
neuralnetwork/main.go

@@ -10,8 +10,8 @@ import (
 )
 
 func main() {
-	sizes := []int{13, 14, 14, 3}
-	nn, _ := neuralnetwork.NewNeuralNetwork(sizes, 130, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
+	sizes := []int{784, 16, 16, 10}
+	nn, _ := neuralnetwork.NewNeuralNetwork(sizes, 1, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
 		NuPlus:   1.2,
 		NuMinus:  0.8,
 		DeltaMax: 50.0,
@@ -29,7 +29,7 @@ func main() {
 	// 	fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
 	// }
 
-	teacher := teach.NewTextDataReader("./wine.data", 1)
+	teacher := teach.NewMNISTReader("./minst.data", "./mnist.labels")
 	nn.Teach(teacher)
 
 	// for i := 0; i < nn.Count; i++ {
@@ -56,7 +56,7 @@ func main() {
 		index, _ := nn.Predict(dataSet)
 		if expect.At(index, 0) != 1.0 {
 			failCount++
-			fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
+			// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
 		}
 	}
 	fmt.Printf("Fail count: %v\n\n", failCount)
@@ -77,7 +77,7 @@ func main() {
 		index, _ := nn.Predict(dataSet)
 		if expect.At(index, 0) != 1.0 {
 			failCount++
-			fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
+			// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
 		}
 	}
 

+ 3 - 3
neuralnetwork/neuralnetworkbase/batchworker.go

@@ -51,9 +51,9 @@ func newBatchWorker(nn *NeuralNetwork) (bw *batchWorker) {
 	return
 }
 
-func (bw *batchWorker) Run(teacher teach.Teacher) {
-	for teacher.NextData() {
-		dB, dW := bw.network.backward(teacher.GetData())
+func (bw *batchWorker) Run(teacher teach.Teacher, startIndex, endIndex int) {
+	for i := startIndex; i < endIndex; i++ {
+		dB, dW := bw.network.backward(teacher.GetDataByIndex(i))
 		for l := 1; l < bw.network.LayerCount; l++ {
 			bw.BGradient[l].AccumGradients(dB[l])
 			bw.WGradient[l].AccumGradients(dW[l])

+ 4 - 4
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -202,15 +202,15 @@ func (nn *NeuralNetwork) TeachOnline(teacher teach.Teacher) {
 	}
 }
 
-func (nn *NeuralNetwork) TeachBatch(_ teach.Teacher) {
+func (nn *NeuralNetwork) TeachBatch(teacher teach.Teacher) {
 	wg := sync.WaitGroup{}
 	for t := 0; t < nn.epocs; t++ {
-		batchWorkers := []*batchWorker{newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn)} //, newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn)}
+		batchWorkers := []*batchWorker{newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn), newBatchWorker(nn)}
 		for i, _ := range batchWorkers {
 			wg.Add(1)
+			s := i
 			go func() {
-				teacher := teach.NewTextDataReader("./wine.data", 5)
-				batchWorkers[i].Run(teacher)
+				batchWorkers[s].Run(teacher, s*teacher.GetDataCount()/len(batchWorkers), (s+1)*teacher.GetDataCount()/len(batchWorkers))
 				wg.Done()
 			}()
 		}

+ 82 - 48
neuralnetwork/teach/mnistreader.go

@@ -36,25 +36,20 @@ import (
 )
 
 type MNISTReader struct {
-	file            *os.File
-	resultsFile     *os.File
-	size            int
-	imageSize       int
-	buffered        *mat.Dense
-	resultsBuffered *mat.Dense
-	window          MNISTBatchWindow
-	currentIndex    int64
+	file                      *os.File
+	resultsFile               *os.File
+	fileValidation            *os.File
+	resultsFileValidation     *os.File
+	size                      int
+	imageSize                 int
+	buffered                  *mat.Dense
+	resultsBuffered           *mat.Dense
+	bufferedValidation        *mat.Dense
+	resultsBufferedValidation *mat.Dense
 }
 
-type MNISTBatchWindow struct {
-	from int64
-	to   int64
-}
-
-func NewMNISTReader(dataFilename string, resultsFilename string, window MNISTBatchWindow) (r *MNISTReader) {
-	r = &MNISTReader{
-		window: window,
-	}
+func NewMNISTReader(dataFilename string, resultsFilename string) (r *MNISTReader) {
+	r = &MNISTReader{}
 
 	var err error
 	r.file, err = os.Open(dataFilename)
@@ -87,25 +82,61 @@ func NewMNISTReader(dataFilename string, resultsFilename string, window MNISTBat
 		return nil
 	}
 
+	//Separation validation part
+	r.fileValidation, err = os.Open(dataFilename)
+	if err != nil {
+		return nil
+	}
+
+	r.resultsFileValidation, err = os.Open(resultsFilename)
+	if err != nil {
+		return nil
+	}
+
 	r.Reset()
 	return
 }
 
-func (r *MNISTReader) GetData() *mat.Dense {
-	return r.buffered
+func (r *MNISTReader) GetData() (*mat.Dense, *mat.Dense) {
+	return r.buffered, r.resultsBuffered
+}
+
+// NextData advances the sequential training cursor by one sample and stores
+// the image and its one-hot label so that GetData can return them.
+//
+// It reads from the training handles (file / resultsFile). The validation
+// handles are reserved for NextValidator; reading them here would make the
+// two cursors share one file offset and skip samples whenever training and
+// validation reads are interleaved.
+//
+// It returns false once readNextData reports end of data (nil, nil). In that
+// case Reset is called, which rewinds both the training and the validation
+// handles to their first sample.
+func (r *MNISTReader) NextData() bool {
+	r.buffered, r.resultsBuffered = r.readNextData(r.file, r.resultsFile)
+	if r.buffered != nil && r.resultsBuffered != nil {
+		return true
+	}
+	r.Reset()
+	return false
+}
+
+func (r *MNISTReader) Reset() {
+	r.file.Seek(16, 0)
+	r.resultsFile.Seek(8, 0)
+
+	r.fileValidation.Seek(16, 0)
+	r.resultsFileValidation.Seek(8, 0)
 }
 
-func (r *MNISTReader) GetExpect() *mat.Dense {
-	return r.resultsBuffered
+func (r *MNISTReader) GetValidator() (*mat.Dense, *mat.Dense) {
+	return r.bufferedValidation, r.resultsBufferedValidation
 }
 
-func (r *MNISTReader) Next() bool {
+func (r *MNISTReader) NextValidator() bool {
+	r.bufferedValidation, r.resultsBufferedValidation = r.readNextData(r.fileValidation, r.resultsFileValidation)
+	if r.bufferedValidation != nil && r.resultsBufferedValidation != nil {
+		return true
+	}
+	r.Reset()
+	return false
+}
+
+func (r *MNISTReader) readNextData(file *os.File, resultsFile *os.File) (buffered, resultsBuffered *mat.Dense) {
 	buffer := make([]byte, r.imageSize)
-	_, err := r.file.Read(buffer)
+	_, err := file.Read(buffer)
 
-	if err == io.EOF || r.currentIndex >= r.window.to {
-		r.Reset()
-		return false
+	if err == io.EOF {
+		return nil, nil
 	} else if err != nil {
 		log.Fatal("File read error\n")
 	}
@@ -115,38 +146,41 @@ func (r *MNISTReader) Next() bool {
 		values[i] = float64(v) / 255.0
 	}
 
-	r.buffered = mat.NewDense(r.imageSize, 1, values)
-
-	// values = make([]float64, len(values))
-	// for i, v := range buffer {
-	// 	if v > 0 {
-	// 		values[i] = 1
-	// 	} else {
-	// 		values[i] = 0
-	// 	}
-	// }
-
-	// squareDense := mat.NewDense(28, 28, values)
-	// fmt.Printf("r.buffered:\n%v\n\n", mat.Formatted(squareDense, mat.Prefix(""), mat.Excerpt(0), mat.Squeeze()))
+	buffered = mat.NewDense(r.imageSize, 1, values)
 
 	buffer = make([]byte, 1)
-	_, err = r.resultsFile.Read(buffer)
+	_, err = resultsFile.Read(buffer)
 	if err != nil {
 		log.Fatal("Result file read error\n")
 	}
 
 	num := int(buffer[0])
 
-	r.resultsBuffered = mat.NewDense(10, 1, nil)
-	r.resultsBuffered.Set(num, 0, 1.0)
+	resultsBuffered = mat.NewDense(10, 1, nil)
+	resultsBuffered.Set(num, 0, 1.0)
 
-	// fmt.Printf("r.resultsBuffered:\n%v\n\n", mat.Formatted(r.resultsBuffered, mat.Prefix(""), mat.Excerpt(0)))
-	r.currentIndex++
+	return buffered, resultsBuffered
+}
 
-	return true
+// GetDataByIndex returns the image and one-hot label of sample i, or
+// (nil, nil) when the sample cannot be fetched: negative index, a file that
+// cannot be opened, a failed seek, or a read that hits end of file.
+//
+// Every call opens its own pair of file handles and closes them before
+// returning, so callers running in different goroutines never share a file
+// offset. Read errors other than EOF are fatal inside readNextData.
+func (r *MNISTReader) GetDataByIndex(i int) (*mat.Dense, *mat.Dense) {
+	// A negative index would seek into (or before) the file header.
+	if i < 0 {
+		return nil, nil
+	}
+
+	file, err := os.Open(r.file.Name())
+	if err != nil {
+		return nil, nil
+	}
+	defer file.Close()
+
+	resultsFile, err := os.Open(r.resultsFile.Name())
+	if err != nil {
+		return nil, nil
+	}
+	defer resultsFile.Close()
+
+	// Skip the leading 16 bytes of the image file and 8 bytes of the label
+	// file (presumably the file headers, matching Reset), then jump to the
+	// i-th record. Widen before multiplying so the offset cannot overflow int.
+	if _, err = file.Seek(16+int64(r.imageSize)*int64(i), io.SeekStart); err != nil {
+		return nil, nil
+	}
+	if _, err = resultsFile.Seek(8+int64(i), io.SeekStart); err != nil {
+		return nil, nil
+	}
+
+	return r.readNextData(file, resultsFile)
 }
 
-func (r *MNISTReader) Reset() {
-	r.file.Seek(16+r.window.from*int64(r.imageSize), 0)
-	r.resultsFile.Seek(8+r.window.from*int64(r.imageSize), 0)
+func (r *MNISTReader) GetDataCount() int {
+	return r.size
 }

+ 3 - 0
neuralnetwork/teach/teacher.go

@@ -33,6 +33,9 @@ type Teacher interface {
 	GetData() (*mat.Dense, *mat.Dense)
 	NextData() bool
 
+	GetDataByIndex(i int) (*mat.Dense, *mat.Dense)
+	GetDataCount() int
+
 	GetValidator() (*mat.Dense, *mat.Dense)
 	NextValidator() bool
 

+ 16 - 4
neuralnetwork/teach/textdatareader.go

@@ -147,14 +147,14 @@ func (r *TextDataReader) readData(filename string) {
 }
 
 func (r *TextDataReader) GetData() (*mat.Dense, *mat.Dense) {
-	r.mutex.Lock()
-	defer r.mutex.Unlock()
+	// r.mutex.Lock()
+	// defer r.mutex.Unlock()
 	return r.dataSet[r.index], r.result[r.index]
 }
 
 func (r *TextDataReader) NextData() bool {
-	r.mutex.Lock()
-	defer r.mutex.Unlock()
+	// r.mutex.Lock()
+	// defer r.mutex.Unlock()
 	if (r.index + 1) >= len(r.result)-r.validationCount {
 		r.index = 0
 		return false
@@ -190,3 +190,15 @@ func (r *TextDataReader) Index() int {
 func (r *TextDataReader) ValidationIndex() int {
 	return r.validationIndex
 }
+
+// GetDataByIndex returns the input vector and expected result stored at
+// position i of the training part of the data set (everything before the
+// last validationCount entries). It returns (nil, nil) when i is negative or
+// lies outside that range, instead of panicking on the slice access.
+func (r *TextDataReader) GetDataByIndex(i int) (*mat.Dense, *mat.Dense) {
+	if i < 0 || i >= len(r.result)-r.validationCount || i >= len(r.dataSet) {
+		return nil, nil
+	}
+
+	return r.dataSet[i], r.result[i]
+}
+
+// GetDataCount returns how many training samples can be fetched by index:
+// the number of entries that have both an input vector and an expected
+// result, minus the trailing validationCount entries. The value is never
+// negative, so callers can split [0, GetDataCount()) into index ranges.
+func (r *TextDataReader) GetDataCount() int {
+	// Bound by the shorter of the two slices so that every index below the
+	// returned count is valid for both dataSet and result.
+	count := len(r.dataSet)
+	if len(r.result) < count {
+		count = len(r.result)
+	}
+
+	count -= r.validationCount
+	if count < 0 {
+		return 0
+	}
+	return count
+}