Forráskód Böngészése

Add documentation for NeuralNetwork

Alexey Edelev 5 éve
szülő
commit
e98f6245c0
1 módosított fájl, 36 hozzáadás és 7 törlés
  1. 36 7
      neuralnetwork/neuralnetwork/neuralnetwork.go

+ 36 - 7
neuralnetwork/neuralnetwork/neuralnetwork.go

@@ -39,14 +39,14 @@ import (
 	mat "gonum.org/v1/gonum/mat"
 )
 
-// NeuralNetwork is simple neural network implementation
+// NeuralNetwork is artificial neural network implementation
 //
 // Resources:
 // http://neuralnetworksanddeeplearning.com
 // https://www.youtube.com/watch?v=fNk_zzaMoSs
 // http://www.inf.fu-berlin.de/lehre/WS06/Musterererkennung/Paper/rprop.pdf
 //
-// Matrix: A
+// Matrix: A (local matrices used in forward and backward methods)
 // Description: A is set of calculated neuron activations after sigmoid correction
 // Format:    0          l           L
 //         ⎡A[0] ⎤ ... ⎡A[0] ⎤ ... ⎡A[0] ⎤
@@ -58,7 +58,7 @@ import (
 // Where s = Sizes[l] - Neural network layer size
 //       L = len(Sizes) - Number of neural network layers
 //
-// Matrix: Z
+// Matrix: Z (local matrices used in forward and backward methods)
 // Description: Z is set of calculated raw neuron activations
 // Format:    0          l           L
 //         ⎡Z[0] ⎤ ... ⎡Z[0] ⎤ ... ⎡Z[0] ⎤
@@ -111,6 +111,12 @@ type NeuralNetwork struct {
 	syncMutex                  *sync.Mutex
 }
 
+// NewNeuralNetwork construction method that initializes new NeuralNetwork based
+// on provided list of layer sizes and GradientDescentInitializer that used for
+// backpropagation mechanism.
+// If gradientDescentInitializer is not provided (is nil) backpropagation won't
+// be possible. Common usecase when it's used is natural selection and genetic
+// training.
 func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentInitializer) (nn *NeuralNetwork, err error) {
 	err = nil
 	if len(sizes) < 3 {
@@ -148,6 +154,9 @@ func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentIni
 	return
 }
 
+// Copy makes complete copy of NeuralNetwork data. Output network has the same
+// weights and biases values and but might be used independend of original one,
+// e.g. in separate goroutine
 func (nn *NeuralNetwork) Copy() (outNN *NeuralNetwork) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -173,6 +182,8 @@ func (nn *NeuralNetwork) Copy() (outNN *NeuralNetwork) {
 	return
 }
 
+// Reset resets network state to intial/random one with specified in argument
+// layers configuration
 func (nn *NeuralNetwork) Reset(sizes []int) (err error) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -209,6 +220,9 @@ func (nn *NeuralNetwork) Reset(sizes []int) (err error) {
 	return
 }
 
+// SetStateWatcher setups state watcher for NeuralNetwork. StateWatcher is common
+// interface that collects data about NeuralNetwork behaivor. If not specified (is
+// set to nil) NeuralNetwork will ignore StateWatcher interations
 func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
 	nn.watcher = watcher
 	if watcher != nil {
@@ -217,6 +231,8 @@ func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
 	}
 }
 
+// Predict method invokes prediction based on input activations provided in argument.
+// Returns index of best element in output activation matrix and its value
 func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -244,6 +260,9 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 	return
 }
 
+// Validate runs basic network validation/verification based on validation data that
+// provided by training.Trainer passed as argument.
+// Returns count of failure predictions and total amount of verified samples.
 func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total int) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -262,21 +281,24 @@ func (nn *NeuralNetwork) Validate(trainer training.Trainer) (failCount, total in
 	return
 }
 
+// Train is common training function that invokes one of training methods depends on
+// gradient descent used buy NeuralNetwork. training.Trainer passed as argument used
+// to get training data. Training loops are limited buy number of epocs
 func (nn *NeuralNetwork) Train(trainer training.Trainer, epocs int) {
 	if nn.watcher != nil {
 		nn.watcher.UpdateState(StateLearning)
 		defer nn.watcher.UpdateState(StateIdle)
 	}
 	if _, ok := nn.WGradient[nn.LayerCount-1].(OnlineGradientDescent); ok {
-		nn.TrainOnline(trainer, epocs)
+		nn.trainOnline(trainer, epocs)
 	} else if _, ok := nn.WGradient[nn.LayerCount-1].(BatchGradientDescent); ok {
-		nn.TrainBatch(trainer, epocs)
+		nn.trainBatch(trainer, epocs)
 	} else {
 		panic("Invalid gradient descent type")
 	}
 }
 
-func (nn *NeuralNetwork) TrainOnline(trainer training.Trainer, epocs int) {
+func (nn *NeuralNetwork) trainOnline(trainer training.Trainer, epocs int) {
 	for t := 0; t < epocs; t++ {
 		for trainer.NextData() {
 			nn.syncMutex.Lock()
@@ -303,7 +325,7 @@ func (nn *NeuralNetwork) TrainOnline(trainer training.Trainer, epocs int) {
 	}
 }
 
-func (nn *NeuralNetwork) TrainBatch(trainer training.Trainer, epocs int) {
+func (nn *NeuralNetwork) trainBatch(trainer training.Trainer, epocs int) {
 	fmt.Printf("Start training in %v threads\n", runtime.NumCPU())
 	for t := 0; t < epocs; t++ {
 		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
@@ -352,6 +374,9 @@ func (nn *NeuralNetwork) runBatchWorkers(threadCount int, trainer training.Train
 	return
 }
 
+// SaveState saves state of NeuralNetwork to io.Writer. It's usefull to keep training results
+// between NeuralNetwork "power cycles" or to share traing results between clustered neural
+// network hosts
 func (nn *NeuralNetwork) SaveState(writer io.Writer) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -384,6 +409,7 @@ func (nn *NeuralNetwork) SaveState(writer io.Writer) {
 	}
 }
 
+// SaveStateToFile saves NeuralNetwork state to file by specific filePath
 func (nn *NeuralNetwork) SaveStateToFile(filePath string) {
 	outFile, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
 	check(err)
@@ -391,6 +417,8 @@ func (nn *NeuralNetwork) SaveStateToFile(filePath string) {
 	nn.SaveState(outFile)
 }
 
+// LoadState loads NeuralNetwork state from io.Reader. All existing data in NeuralNetwork
+// will be rewritten buy this method, including layers configuration and weights and biases
 func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 	nn.syncMutex.Lock()
 	defer nn.syncMutex.Unlock()
@@ -432,6 +460,7 @@ func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 	// fmt.Printf("\nLoadState end\n")
 }
 
+// LoadStateFromFile loads NeuralNetwork state from file by specific filePath
 func (nn *NeuralNetwork) LoadStateFromFile(filePath string) {
 	inFile, err := os.Open(filePath)
 	check(err)