main.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. package main
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. neuralnetwork "./neuralnetworkbase"
  7. teach "./teach"
  8. )
  9. func main() {
  10. sizes := []int{13, 14, 14, 3}
  11. nn, _ := neuralnetwork.NewNeuralNetwork(sizes, 0.1, 481)
  12. // for i := 0; i < nn.Count; i++ {
  13. // if i > 0 {
  14. // fmt.Printf("Weights before:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  15. // fmt.Printf("Biases before:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  16. // fmt.Printf("Z before:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  17. // }
  18. // fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  19. // }
  20. teacher := teach.NewTextDataReader("./wine.data")
  21. nn.Teach(teacher)
  22. // for i := 0; i < nn.Count; i++ {
  23. // if i > 0 {
  24. // fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  25. // fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  26. // fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  27. // }
  28. // fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  29. // }
  30. outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
  31. if err != nil {
  32. log.Fatal(err)
  33. }
  34. defer outFile.Close()
  35. nn.SaveState(outFile)
  36. outFile.Close()
  37. failCount := 0
  38. teacher.Reset()
  39. for teacher.NextData() {
  40. dataSet, expect := teacher.GetData()
  41. index, _ := nn.Predict(dataSet)
  42. if expect.At(index, 0) != 1.0 {
  43. failCount++
  44. fmt.Printf("Fail: %v, %v\n\n", teacher.Index(), expect.At(index, 0))
  45. }
  46. }
  47. fmt.Printf("Fail count: %v\n\n", failCount)
  48. nn = &neuralnetwork.NeuralNetwork{}
  49. inFile, err := os.Open("./data")
  50. if err != nil {
  51. log.Fatal(err)
  52. }
  53. defer inFile.Close()
  54. nn.LoadState(inFile)
  55. inFile.Close()
  56. failCount = 0
  57. teacher.Reset()
  58. for teacher.NextData() {
  59. dataSet, expect := teacher.GetData()
  60. index, _ := nn.Predict(dataSet)
  61. if expect.At(index, 0) != 1.0 {
  62. failCount++
  63. fmt.Printf("Fail: %v, %v\n\n", teacher.Index(), expect.At(index, 0))
  64. }
  65. }
  66. fmt.Printf("Fail count: %v\n\n", failCount)
  67. }