main.go 818 B

123456789101112131415161718192021222324252627282930313233
  1. package main
  2. import (
  3. "fmt"
  4. "math/rand"
  5. neuralnetwork "./neuralnetworkbase"
  6. mat "gonum.org/v1/gonum/mat"
  7. )
  8. func main() {
  9. sizes := []int{3, 2, 2}
  10. nn := neuralnetwork.NewNeuralNetwork(sizes)
  11. data := make([]float64, sizes[0])
  12. for i := range data {
  13. data[i] = rand.Float64()
  14. }
  15. aIn := mat.NewDense(sizes[0], 1, data)
  16. max, index := nn.Predict(aIn)
  17. for i := 0; i < nn.Count; i++ {
  18. if i > 0 {
  19. fmt.Printf("Weights:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  20. fmt.Printf("Biases:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  21. fmt.Printf("Z:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  22. }
  23. fmt.Printf("A:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  24. }
  25. fmt.Printf("Resul: %v, %v\n\n", index, max)
  26. }