main.go 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package main
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. neuralnetwork "./neuralnetworkbase"
  7. teach "./teach"
  8. )
  9. func main() {
  10. sizes := []int{13, 14, 14, 3}
  11. var nn neuralnetwork.NeuralNetwork
  12. nn, _ = neuralnetwork.NewBackProp(sizes, 0.1, 481)
  13. // for i := 0; i < nn.Count; i++ {
  14. // if i > 0 {
  15. // fmt.Printf("Weights before:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  16. // fmt.Printf("Biases before:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  17. // fmt.Printf("Z before:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  18. // }
  19. // fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  20. // }
  21. teacher := teach.NewTextDataReader("./wine.data")
  22. nn.Teach(teacher)
  23. // for i := 0; i < nn.Count; i++ {
  24. // if i > 0 {
  25. // fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  26. // fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  27. // fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  28. // }
  29. // fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  30. // }
  31. outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
  32. if err != nil {
  33. log.Fatal(err)
  34. }
  35. defer outFile.Close()
  36. nn.SaveState(outFile)
  37. outFile.Close()
  38. failCount := 0
  39. teacher.Reset()
  40. for teacher.NextData() {
  41. dataSet, expect := teacher.GetData()
  42. index, _ := nn.Predict(dataSet)
  43. if expect.At(index, 0) != 1.0 {
  44. failCount++
  45. fmt.Printf("Fail: %v, %v\n\n", teacher.Index(), expect.At(index, 0))
  46. }
  47. }
  48. fmt.Printf("Fail count: %v\n\n", failCount)
  49. nn = &neuralnetwork.BackProp{}
  50. inFile, err := os.Open("./data")
  51. if err != nil {
  52. log.Fatal(err)
  53. }
  54. defer inFile.Close()
  55. nn.LoadState(inFile)
  56. inFile.Close()
  57. failCount = 0
  58. teacher.Reset()
  59. for teacher.NextData() {
  60. dataSet, expect := teacher.GetData()
  61. index, _ := nn.Predict(dataSet)
  62. if expect.At(index, 0) != 1.0 {
  63. failCount++
  64. fmt.Printf("Fail: %v, %v\n\n", teacher.Index(), expect.At(index, 0))
  65. }
  66. }
  67. fmt.Printf("Fail count: %v\n\n", failCount)
  68. }