// textdatareader.go (2.0 KB)
  1. package teach
  2. import (
  3. "bufio"
  4. "fmt"
  5. "log"
  6. "os"
  7. "strconv"
  8. "strings"
  9. mat "gonum.org/v1/gonum/mat"
  10. )
  11. type TextDataReader struct {
  12. dataSet []*mat.Dense
  13. result []*mat.Dense
  14. index int
  15. }
  16. func NewTextDataReader(filename string) *TextDataReader {
  17. r := &TextDataReader{
  18. index: 0,
  19. }
  20. r.readData(filename)
  21. return r
  22. }
  23. func (r *TextDataReader) readData(filename string) {
  24. inputFile, err := os.Open(filename)
  25. if err != nil {
  26. log.Fatal(err)
  27. }
  28. defer inputFile.Close()
  29. scanner := bufio.NewScanner(inputFile)
  30. scanner.Split(bufio.ScanLines)
  31. max := []float64{0.0, 0.0, 0.0, 0.0}
  32. for scanner.Scan() {
  33. dataLine := scanner.Text()
  34. data := strings.Split(dataLine, ",")
  35. if len(data) < 5 {
  36. fmt.Printf("Garbage record: %s\n", dataLine)
  37. continue
  38. }
  39. var dataRaw []float64
  40. for i := 0; i < 4; i++ {
  41. val, err := strconv.ParseFloat(data[i], 64)
  42. if err != nil {
  43. break
  44. }
  45. dataRaw = append(dataRaw, val)
  46. if max[i] < val {
  47. max[i] = val
  48. }
  49. }
  50. if len(dataRaw) < 4 {
  51. fmt.Printf("Garbage record: %s\n", dataLine)
  52. continue
  53. }
  54. r.dataSet = append(r.dataSet, mat.NewDense(4, 1, dataRaw))
  55. switch data[4] {
  56. case "Iris-setosa":
  57. r.result = append(r.result, mat.NewDense(3, 1, []float64{1.0, 0.0, 0.0}))
  58. case "Iris-versicolor":
  59. r.result = append(r.result, mat.NewDense(3, 1, []float64{0.0, 1.0, 0.0}))
  60. case "Iris-virginica":
  61. r.result = append(r.result, mat.NewDense(3, 1, []float64{0.0, 0.0, 1.0}))
  62. }
  63. }
  64. //normalize
  65. for i := 0; i < len(r.dataSet); i++ {
  66. r.dataSet[i].Apply(func(r, _ int, val float64) float64 {
  67. return val / max[r]
  68. }, r.dataSet[i])
  69. }
  70. }
  71. func (r *TextDataReader) GetData() *mat.Dense {
  72. return r.dataSet[r.index]
  73. }
  74. func (r *TextDataReader) GetExpect() *mat.Dense {
  75. return r.result[r.index]
  76. }
  77. func (r *TextDataReader) Next() bool {
  78. r.index++
  79. if r.index >= len(r.result) {
  80. r.index = 0
  81. return false
  82. }
  83. return true
  84. }
  85. func (r *TextDataReader) Reset() {
  86. r.index = 0
  87. }
  88. func (r *TextDataReader) Index() int {
  89. return r.index
  90. }