瀏覽代碼

Add simple unittests for neuralnetwork

- Add simple unittests
- Add required checks to NeuralNetwork
Alexey Edelev 5 年之前
父節點
當前提交
c7ffa9f4da
共有 4 個文件被更改,包括 112 次插入7 次删除
  1. 2 2
      build.sh
  2. 1 1
      neuralnetwork/main.go
  3. 69 0
      neuralnetwork/neuralnetwork_test.go
  4. 40 4
      neuralnetwork/neuralnetworkbase/neuralnetwork.go

+ 2 - 2
build.sh

@@ -6,5 +6,5 @@ cd neuralnetwork
 
 go get -v
 go build -o $GOBIN/neuralnetwork
-#go test -v
-#go test -cover
+go test -v
+# go test -cover

+ 1 - 1
neuralnetwork/main.go

@@ -9,7 +9,7 @@ import (
 
 func main() {
 	sizes := []int{4, 8, 8, 3}
-	nn := neuralnetwork.NewNeuralNetwork(sizes, 0.1, 481)
+	nn, _ := neuralnetwork.NewNeuralNetwork(sizes, 0.1, 481)
 
 	// for i := 0; i < nn.Count; i++ {
 	// 	if i > 0 {

+ 69 - 0
neuralnetwork/neuralnetwork_test.go

@@ -0,0 +1,69 @@
+package main
+
+import (
+	"testing"
+
+	"gonum.org/v1/gonum/mat"
+
+	neuralnetwork "./neuralnetworkbase"
+)
+
+func TestNewNeuralNetwork(t *testing.T) {
+	nn, err := neuralnetwork.NewNeuralNetwork([]int{}, 0.1, 500)
+	if nn != nil || err == nil {
+		t.Error("nn initialized, but shouldn't ", err)
+	}
+
+	nn, err = neuralnetwork.NewNeuralNetwork([]int{0, 0, 0, 0}, 0.1, 500)
+	if nn != nil || err == nil {
+		t.Error("nn initialized, but shouldn't ", err)
+	}
+
+	nn, err = neuralnetwork.NewNeuralNetwork([]int{1, 1, 1, 1}, 0.1, 500)
+	if nn != nil || err == nil {
+		t.Error("nn initialized, but shouldn't ", err)
+	}
+
+	nn, err = neuralnetwork.NewNeuralNetwork([]int{5, 5}, 0.1, 500)
+	if nn != nil || err == nil {
+		t.Error("nn initialized, but shouldn't ", err)
+	}
+
+	nn, err = neuralnetwork.NewNeuralNetwork([]int{5, 1, 5, 5}, 0.1, 500)
+	if nn != nil || err == nil {
+		t.Error("nn initialized, but shouldn't ", err)
+	}
+
+	nn, err = neuralnetwork.NewNeuralNetwork([]int{5, 4, 4, 5}, 0.1, 500)
+	if nn == nil || err != nil {
+		t.Error("nn is not initialized, but should be ", err)
+	}
+}
+
+func TestNeuralNetworkPredict(t *testing.T) {
+	nn, _ := neuralnetwork.NewNeuralNetwork([]int{3, 4, 4, 2}, 0.1, 500)
+
+	aIn := &mat.Dense{}
+	index, max := nn.Predict(aIn)
+	if index != -1 || max != 0.0 {
+		t.Error("Prediction when empty aIn shouldn't be possibe but predicted", index, max)
+	}
+
+	aIn = mat.NewDense(2, 1, []float64{0.1, 0.2})
+	index, max = nn.Predict(aIn)
+	if index != -1 || max != 0.0 {
+		t.Error("Prediction aIn has invalid size shouldn't be possibe but predicted", index, max)
+	}
+
+	aIn = mat.NewDense(3, 1, []float64{0.1, 0.2, 0.3})
+	index, max = nn.Predict(aIn)
+	if index == -1 || max == 0.0 {
+		t.Error("Prediction of aIn valid size should be predicted", index, max)
+	}
+
+	aIn = mat.NewDense(4, 1, []float64{0.1, 0.2, 0.3, 0.4})
+	index, max = nn.Predict(aIn)
+	if index != -1 || max != 0.0 {
+		t.Error("Prediction aIn has invalid size shouldn't be possibe but predicted", index, max)
+	}
+}

+ 40 - 4
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -26,6 +26,9 @@
 package neuralnetworkbase
 
 import (
+	"errors"
+	"fmt"
+
 	teach "../teach"
 	mat "gonum.org/v1/gonum/mat"
 )
@@ -91,10 +94,37 @@ type NeuralNetwork struct {
 	trainingCycles int
 }
 
-func NewNeuralNetwork(Sizes []int, nu float64, trainingCycles int) (nn *NeuralNetwork) {
+func NewNeuralNetwork(sizes []int, nu float64, trainingCycles int) (nn *NeuralNetwork, err error) {
+	err = nil
+	if len(sizes) < 3 {
+		fmt.Printf("Invalid network configuration: %v\n", sizes)
+		return nil, errors.New("Invalid network configuration: %v\n")
+	}
+
+	for i := 0; i < len(sizes); i++ {
+		if sizes[i] < 2 {
+			fmt.Printf("Invalid network configuration: %v\n", sizes)
+			return nil, errors.New("Invalid network configuration: %v\n")
+		}
+	}
+
+	if nu <= 0.0 || nu > 1.0 {
+		fmt.Printf("Invalid η value: %v\n", nu)
+		return nil, errors.New("Invalid η value: %v\n")
+	}
+
+	if trainingCycles <= 0 {
+		fmt.Printf("Invalid training cycles number: %v\n", trainingCycles)
+		return nil, errors.New("Invalid training cycles number: %v\n")
+	}
+
+	if trainingCycles < 100 {
+		fmt.Println("Training cycles number probably is too small")
+	}
+
 	nn = &NeuralNetwork{}
-	nn.Sizes = Sizes
-	nn.Count = len(Sizes)
+	nn.Sizes = sizes
+	nn.Count = len(sizes)
 	nn.Weights = make([]*mat.Dense, nn.Count)
 	nn.Biases = make([]*mat.Dense, nn.Count)
 	nn.A = make([]*mat.Dense, nn.Count)
@@ -110,9 +140,15 @@ func NewNeuralNetwork(Sizes []int, nu float64, trainingCycles int) (nn *NeuralNe
 }
 
 func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
+	r, _ := aIn.Dims()
+	if r != nn.Sizes[0] {
+		fmt.Printf("Invalid rows number of input matrix size: %v\n", r)
+		return -1, 0.0
+	}
+
 	nn.forward(aIn)
 	result := nn.result()
-	r, _ := result.Dims()
+	r, _ = result.Dims()
 	max = 0.0
 	maxIndex = 0
 	for i := 0; i < r; i++ {