main.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. package main
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. "time"
  7. neuralnetwork "./neuralnetworkbase"
  8. remotecontrol "./remotecontrol"
  9. teach "./teach"
  10. )
  11. func main() {
  12. sizes := []int{13, 8, 12, 3}
  13. nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
  14. NuPlus: 1.2,
  15. NuMinus: 0.5,
  16. DeltaMax: 50.0,
  17. DeltaMin: 0.000001,
  18. }))
  19. rc := &remotecontrol.RemoteControl{}
  20. nn.SetStateWatcher(rc)
  21. // inFile, err := os.Open("./networkstate")
  22. // if err != nil {
  23. // log.Fatal(err)
  24. // }
  25. // defer inFile.Close()
  26. // nn.LoadState(inFile)
  27. // nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewBackPropInitializer(0.1))
  28. // for i := 0; i < nn.Count; i++ {
  29. // if i > 0 {
  30. // fmt.Printf("Weights before:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  31. // fmt.Printf("Biases before:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  32. // fmt.Printf("Z before:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  33. // }
  34. // fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  35. // }
  36. go func() {
  37. // teacher := teach.NewMNISTReader("./minst.data", "./mnist.labels")
  38. teacher := teach.NewTextDataReader("wine.data", 5)
  39. nn.Teach(teacher, 500)
  40. // for i := 0; i < nn.Count; i++ {
  41. // if i > 0 {
  42. // fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  43. // fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  44. // fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  45. // }
  46. // fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  47. // }
  48. outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
  49. if err != nil {
  50. log.Fatal(err)
  51. }
  52. defer outFile.Close()
  53. nn.SaveState(outFile)
  54. outFile.Close()
  55. time.Sleep(5 * time.Second)
  56. failCount := 0
  57. teacher.Reset()
  58. for true {
  59. if !teacher.NextValidator() {
  60. fmt.Printf("Fail count: %v\n\n", failCount)
  61. failCount = 0
  62. teacher.Reset()
  63. }
  64. dataSet, expect := teacher.GetValidator()
  65. index, _ := nn.Predict(dataSet)
  66. //TODO: remove this is not used for visualization
  67. time.Sleep(400 * time.Millisecond)
  68. if expect.At(index, 0) != 1.0 {
  69. failCount++
  70. // fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
  71. }
  72. }
  73. }()
  74. // nn = &neuralnetwork.NeuralNetwork{}
  75. // inFile, err := os.Open("./data")
  76. // if err != nil {
  77. // log.Fatal(err)
  78. // }
  79. // defer inFile.Close()
  80. // nn.LoadState(inFile)
  81. // inFile.Close()
  82. // failCount = 0
  83. // teacher.Reset()
  84. // for teacher.NextValidator() {
  85. // dataSet, expect := teacher.GetValidator()
  86. // index, _ := nn.Predict(dataSet)
  87. // if expect.At(index, 0) != 1.0 {
  88. // failCount++
  89. // // fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
  90. // }
  91. // }
  92. // fmt.Printf("Fail count: %v\n\n", failCount)
  93. rc.Run()
  94. }