main.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package main
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. neuralnetwork "./neuralnetworkbase"
  7. teach "./teach"
  8. )
  9. func main() {
  10. sizes := []int{13, 16, 16, 3}
  11. nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
  12. NuPlus: 1.2,
  13. NuMinus: 0.5,
  14. DeltaMax: 50.0,
  15. DeltaMin: 0.000001,
  16. }))
  17. // inFile, err := os.Open("./networkstate")
  18. // if err != nil {
  19. // log.Fatal(err)
  20. // }
  21. // defer inFile.Close()
  22. // nn.LoadState(inFile)
  23. // nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewBackPropInitializer(0.1))
  24. // for i := 0; i < nn.Count; i++ {
  25. // if i > 0 {
  26. // fmt.Printf("Weights before:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  27. // fmt.Printf("Biases before:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  28. // fmt.Printf("Z before:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  29. // }
  30. // fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  31. // }
  32. // teacher := teach.NewMNISTReader("./minst.data", "./mnist.labels")
  33. teacher := teach.NewTextDataReader("wine.data", 7)
  34. nn.Teach(teacher, 500)
  35. // for i := 0; i < nn.Count; i++ {
  36. // if i > 0 {
  37. // fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  38. // fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  39. // fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  40. // }
  41. // fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  42. // }
  43. outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
  44. if err != nil {
  45. log.Fatal(err)
  46. }
  47. defer outFile.Close()
  48. nn.SaveState(outFile)
  49. outFile.Close()
  50. failCount := 0
  51. teacher.Reset()
  52. for teacher.NextValidator() {
  53. dataSet, expect := teacher.GetValidator()
  54. index, _ := nn.Predict(dataSet)
  55. if expect.At(index, 0) != 1.0 {
  56. failCount++
  57. // fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
  58. }
  59. }
  60. fmt.Printf("Fail count: %v\n\n", failCount)
  61. // nn = &neuralnetwork.NeuralNetwork{}
  62. // inFile, err := os.Open("./data")
  63. // if err != nil {
  64. // log.Fatal(err)
  65. // }
  66. // defer inFile.Close()
  67. // nn.LoadState(inFile)
  68. // inFile.Close()
  69. // failCount = 0
  70. // teacher.Reset()
  71. // for teacher.NextValidator() {
  72. // dataSet, expect := teacher.GetValidator()
  73. // index, _ := nn.Predict(dataSet)
  74. // if expect.At(index, 0) != 1.0 {
  75. // failCount++
  76. // // fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
  77. // }
  78. // }
  79. // fmt.Printf("Fail count: %v\n\n", failCount)
  80. }