Ver código fonte

Fix issues in training

- Fix deadlock in validation
- Fix MNIST trainer
Alexey Edelev 5 anos atrás
pai
commit
ed798ad9e5
2 arquivos alterados com 11 adições e 2 exclusões
  1. 3 2
      neuralnetwork/neuralnetwork.go
  2. 8 0
      training/mnistreader.go

+ 3 - 2
neuralnetwork/neuralnetwork.go

@@ -268,8 +268,6 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 // provided by training.Trainer passed as argument.
 // Returns count of failure predictions and total amount of verified samples.
 func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total int) {
-	nn.syncMutex.Lock()
-	defer nn.syncMutex.Unlock()
 	failCount = 0
 	total = trainer.ValidatorCount()
 	for i := 0; i < trainer.ValidatorCount(); i++ {
@@ -279,6 +277,9 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
 			failCount++
 		}
 	}
+
+	nn.syncMutex.Lock()
+	defer nn.syncMutex.Unlock()
 	if nn.watcher != nil {
 		if nn.watcher.GetSubscriptionFeatures().Has(ValidationSubscription) {
 			nn.watcher.UpdateValidation(total, failCount)

+ 8 - 0
training/mnistreader.go

@@ -53,6 +53,10 @@ func NewMNISTReader(dataFilename string, resultsFilename string, validatorFilena
 
 	r.dataCount, r.imageSize = openFileSet(dataFilename, resultsFilename)
 	r.validatorCount, _ = openFileSet(validatorFilename, validatorResultsFilename)
+	r.dataFilename = dataFilename
+	r.resultsFilename = resultsFilename
+	r.validatorFilename = validatorFilename
+	r.validatorResultsFilename = validatorResultsFilename
 	if r.dataCount <= 0 || r.imageSize <= 0 || r.validatorCount <= 0 {
 		return nil
 	}
@@ -86,12 +90,14 @@ func (r *mnistReader) ValidatorCount() int {
 func (r *mnistReader) readData(data string, result string, i int) (buffered, resultsBuffered *mat.Dense) {
 	file, err := os.Open(data)
 	if err != nil {
+		log.Fatalf("Could not open data file %v\n", data)
 		return nil, nil
 	}
 	defer file.Close()
 
 	resultsFile, err := os.Open(result)
 	if err != nil {
+		log.Fatalf("Could not open result file %v\n", result)
 		return nil, nil
 	}
 	defer resultsFile.Close()
@@ -103,9 +109,11 @@ func (r *mnistReader) readData(data string, result string, i int) (buffered, res
 	_, err = file.Read(buffer)
 
 	if err == io.EOF {
+		log.Fatal("EOF reached but shouldn't\n")
 		return nil, nil
 	} else if err != nil {
 		log.Fatal("File read error\n")
+		return nil, nil
 	}
 
 	values := make([]float64, r.imageSize)