Prechádzať zdrojové kódy

Rework Trainer interface

- Make trainer interface index-based only
- Add explicit validators to mnist and textdata readers
Alexey Edelev 5 rokov pred
rodič
commit
c00b3dcca7

+ 1 - 0
genetic/genetic.go

@@ -98,6 +98,7 @@ func (p *Population) NaturalSelection(generationCount int) {
 	}
 }
 
+// GetBestNetwork returns the best network in the population according to its fitness
 func (p *Population) GetBestNetwork() *neuralnetwork.NeuralNetwork {
 	return p.bestNetwork
 }

+ 1 - 1
genetic/mutagens/dummymutagen.go

@@ -48,7 +48,7 @@ func NewDummyMutagen(chance float64, mutationCount int) (dm *DummyMutagen) {
 	return
 }
 
-// Mutate implementaion of Mutagen inteface Mutate method
+// Dummy implementation of Mutagen interface Mutate method
 // For DummyMutagen it gets pseudo-random number and validates if number in
 // chance bounds. After method applies randomized mutation for random weight
 // and bias in neuralnetwork.NeuralNetwork

+ 1 - 2
neuralnetwork/batchworker.go

@@ -53,13 +53,12 @@ func newBatchWorker(nn *NeuralNetwork) (bw *batchWorker) {
 
 func (bw *batchWorker) run(trainer training.Trainer, startIndex, endIndex int) {
 	for i := startIndex; i < endIndex; i++ {
-		dB, dW := bw.network.backward(trainer.GetDataByIndex(i))
+		dB, dW := bw.network.backward(trainer.GetData(i))
 		for l := 1; l < bw.network.LayerCount; l++ {
 			bw.BGradient[l].AccumGradients(dB[l])
 			bw.WGradient[l].AccumGradients(dW[l])
 		}
 	}
-	trainer.Reset()
 }
 
 func (bw *batchWorker) result(layer int) (dB, dW *mat.Dense) {

+ 3 - 0
neuralnetwork/gradient.go

@@ -34,12 +34,15 @@ const (
 	WeightGradient = iota
 )
 
+// GradientDescentInitializer is a factory function type that produces a gradient descent object for the given network, layer and gradient type
 type GradientDescentInitializer func(nn *NeuralNetwork, layer, gradientType int) interface{}
 
+// OnlineGradientDescent is the gradient descent interface used by the online training mechanism
 type OnlineGradientDescent interface {
 	ApplyDelta(m mat.Matrix, gradient mat.Matrix) *mat.Dense
 }
 
+// BatchGradientDescent is the gradient descent interface used by the batch training mechanism
 type BatchGradientDescent interface {
 	ApplyDelta(m mat.Matrix) *mat.Dense
 	AccumGradients(gradient mat.Matrix)

+ 5 - 8
neuralnetwork/neuralnetwork.go

@@ -268,16 +268,14 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
 	defer nn.syncMutex.Unlock()
 	failCount = 0
 	total = 0
-	trainer.Reset()
-	for trainer.NextValidator() {
-		dataSet, expect := trainer.GetValidator()
+	for i := 0; i < trainer.ValidatorCount(); i++ {
+		dataSet, expect := trainer.GetValidator(i)
 		index, _ := nn.Predict(dataSet)
 		if expect.At(index, 0) != 1.0 {
 			failCount++
 		}
 		total++
 	}
-	trainer.Reset()
 	return
 }
 
@@ -300,9 +298,9 @@ func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 
 func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
 	for t := 0; t < epocs; t++ {
-		for trainer.NextData() {
+		for i := 0; i < trainer.DataCount(); i++ {
 			nn.syncMutex.Lock()
-			dB, dW := nn.backward(trainer.GetData())
+			dB, dW := nn.backward(trainer.GetData(i))
 			for l := 1; l < nn.LayerCount; l++ {
 				bGradient, ok := nn.BGradient[l].(OnlineGradientDescent)
 				if !ok {
@@ -321,7 +319,6 @@ func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
 			}
 			nn.syncMutex.Unlock()
 		}
-		trainer.Reset()
 	}
 }
 
@@ -359,7 +356,7 @@ func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
 
 func (nn *NeuralNetwork) runBatchWorkers(threadCount int, trainer training.Trainer) (workers []*batchWorker) {
 	wg := sync.WaitGroup{}
-	chunkSize := trainer.GetDataCount() / threadCount
+	chunkSize := trainer.DataCount() / threadCount
 	workers = make([]*batchWorker, threadCount)
 	for i, _ := range workers {
 		workers[i] = newBatchWorker(nn)

+ 2 - 9
remotecontrol/remotecontrol.go

@@ -203,9 +203,8 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 		rw.UpdateState(neuralnetwork.StateLearning)
 		defer rw.UpdateState(neuralnetwork.StateIdle)
 		failCount := 0
-		trainer.Reset()
-		for trainer.NextValidator() {
-			dataSet, expect := trainer.GetValidator()
+		for i := 0; i < trainer.ValidatorCount(); i++ {
+			dataSet, expect := trainer.GetValidator(i)
 			index, _ := rw.nn.Predict(dataSet)
 			//TODO: remove this if not used for visualization
 			time.Sleep(400 * time.Millisecond)
@@ -213,16 +212,10 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 				failCount++
 				// fmt.Printf("Fail: %v, %v\n\n", trainer.ValidationIndex(), expect.At(index, 0))
 			}
-			if !trainer.NextValidator() {
-				fmt.Printf("Fail count: %v\n\n", failCount)
-				failCount = 0
-				trainer.Reset()
-			}
 		}
 
 		fmt.Printf("Fail count: %v\n\n", failCount)
 		failCount = 0
-		trainer.Reset()
 		rw.UpdateState(neuralnetwork.StateIdle)
 	}()
 

+ 68 - 87
training/mnistreader.go

@@ -27,7 +27,6 @@ package training
 
 import (
 	"encoding/binary"
-	"fmt"
 	"io"
 	"log"
 	"os"
@@ -35,12 +34,13 @@ import (
 	mat "gonum.org/v1/gonum/mat"
 )
 
-type MNISTReader struct {
-	file                      *os.File
-	resultsFile               *os.File
-	fileValidation            *os.File
-	resultsFileValidation     *os.File
-	size                      int
+type mnistReader struct {
+	dataFilename              string
+	resultsFilename           string
+	validatorFilename         string
+	validatorResultsFilename  string
+	dataCount                 int
+	validatorCount            int
 	imageSize                 int
 	buffered                  *mat.Dense
 	resultsBuffered           *mat.Dense
@@ -48,92 +48,59 @@ type MNISTReader struct {
 	resultsBufferedValidation *mat.Dense
 }
 
-func NewMNISTReader(dataFilename string, resultsFilename string) (r *MNISTReader) {
-	r = &MNISTReader{}
+func NewMNISTReader(dataFilename string, resultsFilename string, validatorFilename string, validatorResultsFilename string) (r *mnistReader) {
+	r = &mnistReader{}
 
-	var err error
-	r.file, err = os.Open(dataFilename)
-	if err != nil {
-		return nil
-	}
-
-	r.resultsFile, err = os.Open(resultsFilename)
-	if err != nil {
-		return nil
-	}
-
-	buffer := make([]byte, 16)
-	r.file.Read(buffer)
-	header := binary.BigEndian.Uint32(buffer[:4])
-	if header != 0x00000803 {
-		return nil
-	}
-	r.size = int(binary.BigEndian.Uint32(buffer[4:8]))
-	r.imageSize = int(binary.BigEndian.Uint32(buffer[8:12])) * int(binary.BigEndian.Uint32(buffer[12:16]))
-	fmt.Printf("Image size: %v\n", r.imageSize)
-	buffer = make([]byte, 8)
-	r.resultsFile.Read(buffer)
-	header = binary.BigEndian.Uint32(buffer[0:4])
-	if header != 0x00000801 {
-		return nil
-	}
-	resultsSize := int(binary.BigEndian.Uint32(buffer[4:8]))
-	if resultsSize != r.size {
-		return nil
-	}
-
-	//Separation validation part
-	r.fileValidation, err = os.Open(dataFilename)
-	if err != nil {
+	r.dataCount, r.imageSize = openFileSet(dataFilename, resultsFilename)
+	r.validatorCount, _ = openFileSet(validatorFilename, validatorResultsFilename)
+	if r.dataCount <= 0 || r.imageSize <= 0 || r.validatorCount <= 0 {
 		return nil
 	}
+	return
+}
 
-	r.resultsFileValidation, err = os.Open(resultsFilename)
-	if err != nil {
-		return nil
+func (r *mnistReader) GetData(i int) (*mat.Dense, *mat.Dense) {
+	if r.dataCount <= i {
+		return nil, nil
 	}
 
-	r.Reset()
-	return
+	return r.readData(r.dataFilename, r.resultsFilename, i)
 }
 
-func (r *MNISTReader) GetData() (*mat.Dense, *mat.Dense) {
-	return r.buffered, r.resultsBuffered
+func (r *mnistReader) DataCount() int {
+	return r.dataCount
 }
 
-func (r *MNISTReader) NextData() bool {
-	r.buffered, r.resultsBuffered = r.readNextData(r.fileValidation, r.resultsFileValidation)
-	if r.buffered != nil && r.resultsBuffered != nil {
-		return true
+func (r *mnistReader) GetValidator(i int) (data *mat.Dense, result *mat.Dense) {
+	if r.validatorCount <= i {
+		return nil, nil
 	}
-	r.Reset()
-	return false
-}
 
-func (r *MNISTReader) Reset() {
-	r.file.Seek(16, 0)
-	r.resultsFile.Seek(8, 0)
-
-	r.fileValidation.Seek(16, 0)
-	r.resultsFileValidation.Seek(8, 0)
+	return r.readData(r.validatorFilename, r.validatorResultsFilename, i)
 }
 
-func (r *MNISTReader) GetValidator() (*mat.Dense, *mat.Dense) {
-	return r.bufferedValidation, r.resultsBufferedValidation
+func (r *mnistReader) ValidatorCount() int {
+	return r.validatorCount
 }
 
-func (r *MNISTReader) NextValidator() bool {
-	r.bufferedValidation, r.resultsBufferedValidation = r.readNextData(r.fileValidation, r.resultsFileValidation)
-	if r.bufferedValidation != nil && r.resultsBufferedValidation != nil {
-		return true
+func (r *mnistReader) readData(data string, result string, i int) (buffered, resultsBuffered *mat.Dense) {
+	file, err := os.Open(data)
+	if err != nil {
+		return nil, nil
 	}
-	r.Reset()
-	return false
-}
+	defer file.Close()
+
+	resultsFile, err := os.Open(result)
+	if err != nil {
+		return nil, nil
+	}
+	defer resultsFile.Close()
+
+	file.Seek(16+int64(r.imageSize*i), 0)
+	resultsFile.Seek(8+int64(i), 0)
 
-func (r *MNISTReader) readNextData(file *os.File, resultsFile *os.File) (buffered, resultsBuffered *mat.Dense) {
 	buffer := make([]byte, r.imageSize)
-	_, err := file.Read(buffer)
+	_, err = file.Read(buffer)
 
 	if err == io.EOF {
 		return nil, nil
@@ -162,25 +129,39 @@ func (r *MNISTReader) readNextData(file *os.File, resultsFile *os.File) (buffere
 	return buffered, resultsBuffered
 }
 
-func (r *MNISTReader) GetDataByIndex(i int) (*mat.Dense, *mat.Dense) {
-	file, err := os.Open(r.file.Name())
+func openFileSet(dataFilename string, resultsFilename string) (count int, imageSize int) {
+	var err error
+	data, err := os.Open(dataFilename)
 	if err != nil {
-		return nil, nil
+		return -1, -1
 	}
-	defer file.Close()
+	defer data.Close()
 
-	resultsFile, err := os.Open(r.resultsFile.Name())
+	result, err := os.Open(resultsFilename)
 	if err != nil {
-		return nil, nil
+		return -1, -1
 	}
-	defer resultsFile.Close()
+	defer result.Close()
 
-	file.Seek(16+int64(r.imageSize*i), 0)
-	resultsFile.Seek(8+int64(i), 0)
+	buffer := make([]byte, 16)
+	data.Read(buffer)
+	header := binary.BigEndian.Uint32(buffer[:4])
+	if header != 0x00000803 {
+		return -1, -1
+	}
+	count = int(binary.BigEndian.Uint32(buffer[4:8]))
+	imageSize = int(binary.BigEndian.Uint32(buffer[8:12])) * int(binary.BigEndian.Uint32(buffer[12:16]))
 
-	return r.readNextData(file, resultsFile)
-}
+	buffer = make([]byte, 8)
+	result.Read(buffer)
+	header = binary.BigEndian.Uint32(buffer[0:4])
+	if header != 0x00000801 {
+		return -1, -1
+	}
+	resultsCount := int(binary.BigEndian.Uint32(buffer[4:8]))
+	if resultsCount != count {
+		return -1, -1
+	}
 
-func (r *MNISTReader) GetDataCount() int {
-	return r.size
+	return
 }

+ 16 - 54
training/textdatareader.go

@@ -40,23 +40,18 @@ import (
 )
 
 type TextDataReader struct {
-	dataSet         []*mat.Dense
-	result          []*mat.Dense
-	index           int
-	validationIndex int
-	validationCount int
-	mutex           *sync.Mutex
+	dataSet   []*mat.Dense
+	result    []*mat.Dense
+	dataCount int
+	mutex     *sync.Mutex
 }
 
 func NewTextDataReader(filename string, validationPart int) *TextDataReader {
 	r := &TextDataReader{
-		index:           0,
-		validationIndex: 0,
 		mutex:           &sync.Mutex{},
 	}
 	r.readData(filename)
-	r.validationCount = len(r.dataSet) / validationPart
-	r.validationIndex = len(r.dataSet) - r.validationCount
+	r.dataCount = int((float64(len(r.dataSet)) * float64(100.0 - validationPart)) / 100.0)
 	return r
 }
 func (r *TextDataReader) readData(filename string) {
@@ -146,59 +141,26 @@ func (r *TextDataReader) readData(filename string) {
 	}
 }
 
-func (r *TextDataReader) GetData() (*mat.Dense, *mat.Dense) {
-	// r.mutex.Lock()
-	// defer r.mutex.Unlock()
-	return r.dataSet[r.index], r.result[r.index]
-}
-
-func (r *TextDataReader) NextData() bool {
-	// r.mutex.Lock()
-	// defer r.mutex.Unlock()
-	if (r.index + 1) >= len(r.result)-r.validationCount {
-		r.index = 0
-		return false
-	}
-	r.index++
-
-	return true
-}
-
-func (r *TextDataReader) GetValidator() (*mat.Dense, *mat.Dense) {
-	return r.dataSet[r.validationIndex], r.result[r.validationIndex]
-}
-
-func (r *TextDataReader) NextValidator() bool {
-	if (r.validationIndex + 1) >= len(r.dataSet) {
-		r.validationIndex = len(r.dataSet) - r.validationCount
-		return false
+func (r *TextDataReader) GetData(i int) (*mat.Dense, *mat.Dense) {
+	if i >= r.dataCount {
+		return nil, nil
 	}
-	r.validationIndex++
-
-	return true
-}
 
-func (r *TextDataReader) Reset() {
-	r.index = 0
-	r.validationIndex = len(r.dataSet) - r.validationCount
-}
-
-func (r *TextDataReader) Index() int {
-	return r.index
+	return r.dataSet[i], r.result[i]
 }
 
-func (r *TextDataReader) ValidationIndex() int {
-	return r.validationIndex
+func (r *TextDataReader) DataCount() int {
+	return r.dataCount;
 }
 
-func (r *TextDataReader) GetDataByIndex(i int) (*mat.Dense, *mat.Dense) {
-	if i >= len(r.result)-r.validationCount {
+func (r *TextDataReader) GetValidator(i int) (*mat.Dense, *mat.Dense) {
+	if i >= len(r.result) - r.dataCount {
 		return nil, nil
 	}
 
-	return r.dataSet[i], r.result[i]
+	return r.dataSet[r.dataCount + i], r.result[r.dataCount + i]
 }
 
-func (r *TextDataReader) GetDataCount() int {
-	return len(r.dataSet) - r.validationCount
+func (r *TextDataReader) ValidatorCount() int {
+	return len(r.result) -  r.dataCount;
 }

+ 4 - 9
training/trainer.go

@@ -31,14 +31,9 @@ import (
 
 // Trainer is the basic interface for neuralnetwork.NeuralNetwork training and validation
 type Trainer interface {
-	GetData() (*mat.Dense, *mat.Dense)
-	NextData() bool
+	GetData(i int) (*mat.Dense, *mat.Dense)
+	DataCount() int
 
-	GetDataByIndex(i int) (*mat.Dense, *mat.Dense)
-	GetDataCount() int
-
-	GetValidator() (*mat.Dense, *mat.Dense)
-	NextValidator() bool
-
-	Reset()
+	GetValidator(i int) (*mat.Dense, *mat.Dense)
+	ValidatorCount() int
 }