Pārlūkot izejas kodu

Fix issues with selection of parents

Alexey Edelev 5 gadi atpakaļ
vecāks
revīzija
a302c914e2

+ 11 - 3
neuralnetwork/genetic/genetic.go

@@ -1,7 +1,9 @@
 package genetic
 
 import (
+	"fmt"
 	"log"
+	"math/rand"
 
 	"sort"
 
@@ -51,8 +53,8 @@ func (p *Population) NaturalSelection(generationCount int) {
 }
 
 func (p *Population) crossbreedPopulation(results []*IndividalResult) {
-	sort.Slice(results, func(i, j int) bool {
-		return results[i].Result < results[j].Result
+	sort.SliceStable(results, func(i, j int) bool {
+		return results[i].Result > results[j].Result // Descending order: best on top, worst at the bottom
 	})
 
 	etalons := int(float64(p.populationConfig.PopulationSize) * p.populationConfig.SelectionSize)
@@ -61,6 +63,7 @@ func (p *Population) crossbreedPopulation(results []*IndividalResult) {
 		secondParentBase := results[(i-1)%etalons].Index
 		firstParent := results[i].Index
 		secondParent := results[i-1].Index
+		fmt.Printf("Result index %v value %v i %v\n", results[i].Index, results[i].Result, i)
 		p.Networks[firstParent] = p.Networks[firstParentBase].Copy()
 		p.Networks[secondParent] = p.Networks[secondParentBase].Copy()
 		crossbreed(p.Networks[firstParent], p.Networks[secondParent], p.populationConfig.CrossbreedPart)
@@ -76,7 +79,12 @@ func crossbreed(firstParent, secondParent *neuralnetwork.NeuralNetwork, crossbre
 		firstParentBiases := firstParent.Biases[l]
 		secondParentBiases := secondParent.Biases[l]
 		r, c := firstParentWeights.Dims()
-		for i := 0; i < int(float64(r)*crossbreedPart); i++ {
+		rp := int(float64(r) * crossbreedPart)
+		cp := int(float64(c) * crossbreedPart)
+		r = int(rand.Uint32())%(r-rp) + rp
+		c = int(rand.Uint32())%(c-cp) + cp
+		// for i := 0; i < int(float64(r)*crossbreedPart); i++ {
+		for i := 0; i < r; i++ {
 			for j := 0; j < c; j++ {
 				// Swap first half of weights
 				w := firstParentWeights.At(i, j)

+ 3 - 1
neuralnetwork/genetic/mutagen/dummymutagen.go

@@ -27,8 +27,10 @@ func (rm *DummyMutagen) Mutate(network *neuralnetwork.NeuralNetwork) {
 			for o := 0; o < 10; o++ {
 				mutationRow := int(rand.Uint32()) % r
 				mutationColumn := int(rand.Uint32()) % c
-				weight := rand.Float64()
+				weight := rand.NormFloat64()
+				bias := rand.NormFloat64()
 				network.Weights[l].Set(mutationRow, mutationColumn, weight)
+				network.Biases[l].Set(mutationRow, 0, bias)
 			}
 		}
 	}

+ 1 - 1
neuralnetwork/main.go

@@ -12,7 +12,7 @@ func main() {
 	go rc.Run()
 	s := snakesimulator.NewSnakeSimulator()
 	s.StartServer()
-	p := genetic.NewPopulation(s, mutagen.NewDummyMutagen(50), genetic.PopulationConfig{PopulationSize: 400, SelectionSize: 0.1, CrossbreedPart: 0.5}, []int{20, 18, 18, 4})
+	p := genetic.NewPopulation(s, mutagen.NewDummyMutagen(50), genetic.PopulationConfig{PopulationSize: 500, SelectionSize: 0.05, CrossbreedPart: 0.2}, []int{24, 20, 20, 4})
 	for _, net := range p.Networks {
 		net.SetStateWatcher(rc)
 	}

+ 3 - 0
neuralnetwork/neuralnetwork/mathcommon.go

@@ -36,8 +36,11 @@ import (
 func generateRandomDense(rows, columns int) *mat.Dense {
 	rand.Seed(time.Now().UnixNano())
 	data := make([]float64, rows*columns)
+	// min := -1.0
+	// max := 1.0
 	for i := range data {
 		data[i] = rand.NormFloat64()
+		// data[i] = min + rand.Float64()*(max-min)
 	}
 	return mat.NewDense(rows, columns, data)
 }

+ 156 - 47
neuralnetwork/snakesimulator/snakesimulator.go

@@ -52,13 +52,12 @@ func NewSnakeSimulator() (s *SnakeSimulator) {
 func (s *SnakeSimulator) Verify(population *genetic.Population) (results []*genetic.IndividalResult) {
 	s.stats.Generation++
 	s.statsUpdateQueue <- true
-
+	s.field.GenerateNextFood()
 	results = make([]*genetic.IndividalResult, len(population.Networks))
 	for index, inidividual := range population.Networks {
 		s.stats.Individual = uint32(index)
 		s.statsUpdateQueue <- true
 
-		s.field.GenerateNextFood()
 		rand.Seed(time.Now().UnixNano())
 		switch rand.Uint32() % 4 {
 		case 1:
@@ -98,7 +97,10 @@ func (s *SnakeSimulator) Verify(population *genetic.Population) (results []*gene
 		i := 0
 		s.stats.Move = 0
 		for i < 300 {
-			s.statsUpdateQueue <- true
+			if s.speed > 0 {
+				s.statsUpdateQueue <- true
+			}
+
 			//Read speed from client
 			select {
 			case newSpeed := <-s.speedQueue:
@@ -118,34 +120,49 @@ func (s *SnakeSimulator) Verify(population *genetic.Population) (results []*gene
 			}
 			direction, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.GetHeadState()))
 			newHead := s.snake.NewHead(Direction(direction + 1))
+
+			if selfCollisionIndex := s.snake.SelfCollision(newHead); selfCollisionIndex > 0 {
+				if selfCollisionIndex == 1 {
+					// switch Direction(direction + 1) {
+					// case Direction_Up:
+					// 	newHead = s.snake.NewHead(Direction_Down)
+					// case Direction_Down:
+					// 	newHead = s.snake.NewHead(Direction_Up)
+					// case Direction_Left:
+					// 	newHead = s.snake.NewHead(Direction_Right)
+					// default:
+					// 	newHead = s.snake.NewHead(Direction_Left)
+					// }
+					continue
+				} else {
+					fmt.Printf("Game over self collision\n")
+					break
+				}
+			}
+
 			if newHead.X == s.field.Food.X && newHead.Y == s.field.Food.Y {
 				s.snake.Feed(newHead)
 				s.field.GenerateNextFood()
-				s.fieldUpdateQueue <- true
+				if s.speed > 0 {
+					s.fieldUpdateQueue <- true
+				}
 			} else if newHead.X >= s.field.Width || newHead.Y >= s.field.Height {
 				// fmt.Printf("Game over\n")
 				// time.Sleep(1000 * time.Millisecond)
 				break
-			} else if selfCollisionIndex := s.snake.SelfCollision(newHead); selfCollisionIndex > 0 {
-				if selfCollisionIndex == 1 {
-					// fmt.Printf("Step backward, skip\n")
-					continue
-				}
-				fmt.Printf("Game over self collision\n")
-				break
 			} else {
 				s.snake.Move(newHead)
 			}
-			s.stats.Move++
-
 			if s.speed > 0 {
 				time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
 			}
+			s.stats.Move++
 		}
 
 		results[index] = &genetic.IndividalResult{
 			Result: float64(len(s.snake.Points)-2) * float64(s.stats.Move),
-			Index:  index,
+			// Result: float64(s.stats.Move),
+			Index: index,
 		}
 	}
 	return
@@ -164,49 +181,105 @@ func (s *SnakeSimulator) GetHeadState() []float64 {
 	height := float64(s.field.Height)
 	diag := float64(width) * math.Sqrt2
 
-	tBack := float64(1.0)
-	bBack := float64(1.0)
-	lBack := float64(1.0)
-	rBack := float64(1.0)
+	// tBack := float64(0.0)
+	// bBack := float64(0.0)
+	// lBack := float64(0.0)
+	// rBack := float64(0.0)
 	// if prevX == headX {
 	// 	if prevY > headY {
-	// 		tBack = 0.0
+	// 		bBack = 1.0
 	// 	} else {
-	// 		bBack = 0.0
+	// 		tBack = 1.0
 	// 	}
 	// }
 
 	// if prevY == headY {
 	// 	if prevX > headX {
-	// 		rBack = 0.0
+	// 		lBack = 1.0
 	// 	} else {
-	// 		lBack = 0.0
+	// 		rBack = 1.0
 	// 	}
 	// }
 
-	lWall := headX * lBack
-	rWall := (width - headX) * rBack
-	tWall := headY * tBack
-	bWall := (height - headY) * bBack
-	lFood := float64(0)
-	rFood := float64(0)
-	tFood := float64(0)
-	bFood := float64(0)
-
-	tlFood := float64(0)
-	trFood := float64(0)
-	blFood := float64(0)
-	brFood := float64(0)
+	lWall := headX
+	rWall := (width - headX)
+	tWall := headY
+	bWall := (height - headY)
+	lFood := float64(width)
+	rFood := float64(width)
+	tFood := float64(height)
+	bFood := float64(height)
+
+	tlFood := float64(diag)
+	trFood := float64(diag)
+	blFood := float64(diag)
+	brFood := float64(diag)
 	tlWall := float64(0)
 	trWall := float64(0)
 	blWall := float64(0)
 	brWall := float64(0)
 
-	tFood = (1.0 - (headY-foodY)/height) * tBack
-	bFood = (1.0 - (foodY-headY)/height) * bBack
+	if headX == foodX {
+		tFood = headY - foodY
+		if tFood < 0 {
+			tFood = height
+		}
+		bFood = foodY - headY
+		if bFood < 0 {
+			bFood = height
+		}
+	}
+
+	if headY == foodY {
+		rFood = foodX - headX
+		if rFood < 0 {
+			rFood = width
+		}
+		lFood = headX - foodX
+		if lFood < 0 {
+			lFood = width
+		}
+	}
+
+	if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
+		if foodX > headX {
+			if foodY > headY {
+				trFood = math.Abs(foodX-headX) * math.Sqrt2
+			} else {
+				brFood = math.Abs(foodX-headX) * math.Sqrt2
+			}
+		} else {
+			if foodY > headY {
+				tlFood = math.Abs(foodX-headX) * math.Sqrt2
+			} else {
+				blFood = math.Abs(foodX-headX) * math.Sqrt2
+			}
+		}
+	}
+
+	if lWall > tWall {
+		tlWall = float64(tWall) * math.Sqrt2
+	} else {
+		tlWall = float64(lWall) * math.Sqrt2
+	}
+
+	if rWall > tWall {
+		trWall = float64(tWall) * math.Sqrt2
+	} else {
+		trWall = float64(rWall) * math.Sqrt2
+	}
 
-	rFood = (1.0 - (foodX-headX)/width) * rBack
-	lFood = (1.0 - (headX-foodX)/width) * lBack
+	if lWall > bWall {
+		blWall = float64(bWall) * math.Sqrt2
+	} else {
+		blWall = float64(lWall) * math.Sqrt2
+	}
+
+	if rWall > bWall {
+		blWall = float64(bWall) * math.Sqrt2
+	} else {
+		brWall = float64(rWall) * math.Sqrt2
+	}
 
 	if lWall > tWall {
 		tlWall = float64(tWall) * math.Sqrt2
@@ -233,31 +306,67 @@ func (s *SnakeSimulator) GetHeadState() []float64 {
 	}
 
 	tTail := (headY - tailY)
+	if tTail < 0 {
+		tTail = height
+	}
 	bTail := (tailY - headY)
+	if bTail < 0 {
+		bTail = height
+	}
 	lTail := (headX - tailX)
+	if lTail < 0 {
+		tTail = width
+	}
 	rTail := (tailX - headX)
+	if lTail < 0 {
+		tTail = width
+	}
+
+	tlTail := float64(diag)
+	trTail := float64(diag)
+	blTail := float64(diag)
+	brTail := float64(diag)
+	if math.Abs(headY-tailY) == math.Abs(headX-tailX) {
+		if tailY > headY {
+			if tailX > headX {
+				trTail = math.Abs(tailX-headX) * math.Sqrt2
+			} else {
+				tlTail = math.Abs(tailX-headX) * math.Sqrt2
+			}
+		} else {
+			if tailX > headX {
+				brTail = math.Abs(tailX-headX) * math.Sqrt2
+			} else {
+				blTail = math.Abs(tailX-headX) * math.Sqrt2
+			}
+		}
+	}
 
 	return []float64{
 		lWall / width,
-		lFood,
 		rWall / width,
-		rFood,
 		tWall / height,
-		tFood,
 		bWall / height,
-		bFood,
+		(1.0 - lFood/width),
+		(1.0 - rFood/width),
+		(1.0 - tFood/height),
+		(1.0 - bFood/height),
 		tlWall / diag,
-		tlFood,
 		trWall / diag,
-		trFood,
 		blWall / diag,
-		blFood,
 		brWall / diag,
-		brFood,
+		(1.0 - tlFood/diag),
+		(1.0 - trFood/diag),
+		(1.0 - blFood/diag),
+		(1.0 - brFood/diag),
 		tTail / height,
 		bTail / height,
 		lTail / width,
 		rTail / width,
+		tlTail / diag,
+		trTail / diag,
+		blTail / diag,
+		brTail / diag,
 	}
 }