|
@@ -29,22 +29,27 @@ import (
|
|
|
"bufio"
|
|
|
"fmt"
|
|
|
"log"
|
|
|
+ "math/rand"
|
|
|
"os"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
|
|
|
mat "gonum.org/v1/gonum/mat"
|
|
|
)
|
|
|
|
|
|
type TextDataReader struct {
|
|
|
- dataSet []*mat.Dense
|
|
|
- result []*mat.Dense
|
|
|
- index int
|
|
|
+ dataSet []*mat.Dense
|
|
|
+ result []*mat.Dense
|
|
|
+ index int
|
|
|
+ validationIndex int
|
|
|
+ validationCount int
|
|
|
}
|
|
|
|
|
|
func NewTextDataReader(filename string) *TextDataReader {
|
|
|
r := &TextDataReader{
|
|
|
- index: 0,
|
|
|
+ index: 0,
|
|
|
+ validationIndex: 0,
|
|
|
}
|
|
|
r.readData(filename)
|
|
|
|
|
@@ -127,28 +132,50 @@ func (r *TextDataReader) readData(filename string) {
|
|
|
return val / max[r]
|
|
|
}, r.dataSet[i])
|
|
|
}
|
|
|
-}
|
|
|
|
|
|
-func (r *TextDataReader) GetData() *mat.Dense {
|
|
|
- return r.dataSet[r.index]
|
|
|
+ rand.Seed(time.Now().UnixNano())
|
|
|
+ for k := 0; k < 5; k++ {
|
|
|
+ rand.Shuffle(len(r.dataSet), func(i, j int) {
|
|
|
+ r.result[i], r.result[j] = r.result[j], r.result[i]
|
|
|
+ r.dataSet[i], r.dataSet[j] = r.dataSet[j], r.dataSet[i]
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ r.validationCount = 0
|
|
|
+ r.validationIndex = len(r.dataSet) - r.validationCount
|
|
|
}
|
|
|
|
|
|
-func (r *TextDataReader) GetExpect() *mat.Dense {
|
|
|
- return r.result[r.index]
|
|
|
+func (r *TextDataReader) GetData() (*mat.Dense, *mat.Dense) {
|
|
|
+ return r.dataSet[r.index], r.result[r.index]
|
|
|
}
|
|
|
|
|
|
-func (r *TextDataReader) Next() bool {
|
|
|
- r.index++
|
|
|
- if r.index >= len(r.result) {
|
|
|
+func (r *TextDataReader) NextData() bool {
|
|
|
+ if (r.index + 1) >= len(r.result)-r.validationCount {
|
|
|
r.index = 0
|
|
|
return false
|
|
|
}
|
|
|
+ r.index++
|
|
|
+
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+func (r *TextDataReader) GetValidator() (*mat.Dense, *mat.Dense) {
|
|
|
+ return r.dataSet[r.validationIndex], r.result[r.validationIndex]
|
|
|
+}
|
|
|
+
|
|
|
+func (r *TextDataReader) NextValidator() bool {
|
|
|
+ if (r.validationIndex + 1) >= len(r.dataSet) {
|
|
|
+ r.validationIndex = len(r.dataSet) - r.validationCount
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ r.validationIndex++
|
|
|
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
func (r *TextDataReader) Reset() {
|
|
|
r.index = 0
|
|
|
+ r.validationIndex = len(r.dataSet) - r.validationCount
|
|
|
}
|
|
|
|
|
|
func (r *TextDataReader) Index() int {
|