main.go 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package main
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. neuralnetwork "./neuralnetworkbase"
  7. teach "./teach"
  8. )
  9. func main() {
  10. sizes := []int{13, 14, 14, 3}
  11. nn, _ := neuralnetwork.NewNeuralNetwork(sizes, 200, neuralnetwork.NewPlusRPropInitializer(neuralnetwork.RPropConfig{
  12. NuPlus: 1.2,
  13. NuMinus: 0.8,
  14. DeltaMax: 50.0,
  15. DeltaMin: 0.000001,
  16. }))
  17. // nn, _ := neuralnetwork.NewNeuralNetwork(sizes, 200, neuralnetwork.NewBackPropInitializer(0.1))
  18. // for i := 0; i < nn.Count; i++ {
  19. // if i > 0 {
  20. // fmt.Printf("Weights before:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  21. // fmt.Printf("Biases before:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  22. // fmt.Printf("Z before:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  23. // }
  24. // fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  25. // }
  26. teacher := teach.NewTextDataReader("./wine.data", 5)
  27. nn.Teach(teacher)
  28. // for i := 0; i < nn.Count; i++ {
  29. // if i > 0 {
  30. // fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
  31. // fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
  32. // fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
  33. // }
  34. // fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
  35. // }
  36. outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
  37. if err != nil {
  38. log.Fatal(err)
  39. }
  40. defer outFile.Close()
  41. nn.SaveState(outFile)
  42. outFile.Close()
  43. failCount := 0
  44. teacher.Reset()
  45. for teacher.NextValidator() {
  46. dataSet, expect := teacher.GetValidator()
  47. index, _ := nn.Predict(dataSet)
  48. if expect.At(index, 0) != 1.0 {
  49. failCount++
  50. fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
  51. }
  52. }
  53. fmt.Printf("Fail count: %v\n\n", failCount)
  54. nn = &neuralnetwork.NeuralNetwork{}
  55. inFile, err := os.Open("./data")
  56. if err != nil {
  57. log.Fatal(err)
  58. }
  59. defer inFile.Close()
  60. nn.LoadState(inFile)
  61. inFile.Close()
  62. failCount = 0
  63. teacher.Reset()
  64. for teacher.NextValidator() {
  65. dataSet, expect := teacher.GetValidator()
  66. index, _ := nn.Predict(dataSet)
  67. if expect.At(index, 0) != 1.0 {
  68. failCount++
  69. fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
  70. }
  71. }
  72. fmt.Printf("Fail count: %v\n\n", failCount)
  73. }