snakesimulator.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. package snakesimulator
  2. import (
  3. fmt "fmt"
  4. math "math"
  5. "math/rand"
  6. "net"
  7. "time"
  8. "gonum.org/v1/gonum/mat"
  9. genetic "../genetic"
  10. grpc "google.golang.org/grpc"
  11. )
  12. type SnakeSimulator struct {
  13. field *Field
  14. snake *Snake
  15. stats *Stats
  16. fieldUpdateQueue chan bool
  17. snakeUpdateQueue chan bool
  18. statsUpdateQueue chan bool
  19. }
  20. func NewSnakeSimulator() (s *SnakeSimulator) {
  21. s = &SnakeSimulator{
  22. field: &Field{
  23. Food: &Point{},
  24. Width: 40,
  25. Height: 40,
  26. },
  27. snake: &Snake{
  28. Points: []*Point{
  29. &Point{X: 20, Y: 20},
  30. &Point{X: 21, Y: 20},
  31. &Point{X: 22, Y: 20},
  32. },
  33. },
  34. stats: &Stats{},
  35. fieldUpdateQueue: make(chan bool, 2),
  36. snakeUpdateQueue: make(chan bool, 2),
  37. statsUpdateQueue: make(chan bool, 2),
  38. }
  39. return
  40. }
  41. func (s *SnakeSimulator) Verify(population *genetic.Population) (results []*genetic.IndividalResult) {
  42. s.stats.Generation++
  43. s.statsUpdateQueue <- true
  44. results = make([]*genetic.IndividalResult, len(population.Networks))
  45. for index, inidividual := range population.Networks {
  46. s.stats.Individual = uint32(index)
  47. s.statsUpdateQueue <- true
  48. s.field.GenerateNextFood()
  49. s.snake = &Snake{
  50. Points: []*Point{
  51. &Point{X: 20, Y: 20},
  52. &Point{X: 21, Y: 20},
  53. &Point{X: 22, Y: 20},
  54. },
  55. }
  56. i := 0
  57. for i < 300 {
  58. s.stats.Move = uint32(i)
  59. s.statsUpdateQueue <- true
  60. i++
  61. s.snakeUpdateQueue <- true
  62. direction, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.GetHeadState()))
  63. newHead := s.snake.NewHead(Direction(direction + 1))
  64. if newHead.X == s.field.Food.X && newHead.Y == s.field.Food.Y {
  65. s.snake.Feed(newHead)
  66. s.field.GenerateNextFood()
  67. s.fieldUpdateQueue <- true
  68. } else if newHead.X >= s.field.Width || newHead.Y >= s.field.Height {
  69. fmt.Printf("Game over\n")
  70. // time.Sleep(1000 * time.Millisecond)
  71. break
  72. } else if selfCollisionIndex := s.snake.SelfCollision(newHead); selfCollisionIndex > 0 {
  73. if selfCollisionIndex == 1 {
  74. fmt.Printf("Step backward, skip\n")
  75. continue
  76. }
  77. fmt.Printf("Game over self collision\n")
  78. break
  79. } else {
  80. s.snake.Move(newHead)
  81. }
  82. time.Sleep(10 * time.Millisecond)
  83. }
  84. results[index] = &genetic.IndividalResult{
  85. Result: float64(len(s.snake.Points)) * float64(i),
  86. Index: index,
  87. }
  88. }
  89. return
  90. }
  91. func (s *SnakeSimulator) GetHeadState() []float64 {
  92. headX := int32(s.snake.Points[0].X)
  93. headY := int32(s.snake.Points[0].Y)
  94. foodX := int32(s.field.Food.X)
  95. foodY := int32(s.field.Food.Y)
  96. width := int32(s.field.Width)
  97. height := int32(s.field.Height)
  98. diag := float64(width) * math.Sqrt2
  99. lWall := headX
  100. rWall := width - headX
  101. tWall := headY
  102. bWall := height - headY
  103. lFood := int32(0)
  104. rFood := int32(0)
  105. tFood := int32(0)
  106. bFood := int32(0)
  107. tlFood := float64(0)
  108. trFood := float64(0)
  109. blFood := float64(0)
  110. brFood := float64(0)
  111. tlWall := float64(0)
  112. trWall := float64(0)
  113. blWall := float64(0)
  114. brWall := float64(0)
  115. if foodX == headX {
  116. if foodY > headY {
  117. bFood = foodY - headY
  118. } else {
  119. tFood = headY - foodY
  120. }
  121. }
  122. if foodY == headY {
  123. if foodX > headX {
  124. rFood = foodX - headX
  125. } else {
  126. lFood = headX - foodX
  127. }
  128. }
  129. if lWall > tWall {
  130. tlWall = float64(tWall) * math.Sqrt2
  131. } else {
  132. tlWall = float64(lWall) * math.Sqrt2
  133. }
  134. if rWall > tWall {
  135. trWall = float64(tWall) * math.Sqrt2
  136. } else {
  137. trWall = float64(rWall) * math.Sqrt2
  138. }
  139. if lWall > bWall {
  140. blWall = float64(bWall) * math.Sqrt2
  141. } else {
  142. blWall = float64(lWall) * math.Sqrt2
  143. }
  144. if rWall > bWall {
  145. blWall = float64(bWall) * math.Sqrt2
  146. } else {
  147. brWall = float64(rWall) * math.Sqrt2
  148. }
  149. foodDiagXDiff := math.Abs(float64(foodX - headX))
  150. foodDiagYDiff := math.Abs(float64(foodY - headY))
  151. if foodDiagXDiff == foodDiagYDiff {
  152. if math.Signbit(float64(foodX - headX)) {
  153. if math.Signbit(float64(foodY - headY)) {
  154. trFood = foodDiagXDiff * math.Sqrt2
  155. } else {
  156. brFood = foodDiagXDiff * math.Sqrt2
  157. }
  158. } else {
  159. if math.Signbit(float64(foodY - headY)) {
  160. tlFood = foodDiagXDiff * math.Sqrt2
  161. } else {
  162. blFood = foodDiagXDiff * math.Sqrt2
  163. }
  164. }
  165. }
  166. return []float64{
  167. float64(lWall) / float64(width),
  168. float64(rWall) / float64(width),
  169. float64(tWall) / float64(height),
  170. float64(bWall) / float64(height),
  171. float64(lFood) / float64(width),
  172. float64(rFood) / float64(width),
  173. float64(tFood) / float64(height),
  174. float64(bFood) / float64(height),
  175. float64(tlFood) / diag,
  176. float64(trFood) / diag,
  177. float64(blFood) / diag,
  178. float64(brFood) / diag,
  179. float64(tlWall) / diag,
  180. float64(trWall) / diag,
  181. float64(blWall) / diag,
  182. float64(brWall) / diag,
  183. }
  184. }
  185. func (s *SnakeSimulator) StartServer() {
  186. go func() {
  187. grpcServer := grpc.NewServer()
  188. RegisterSnakeSimulatorServer(grpcServer, s)
  189. lis, err := net.Listen("tcp", "localhost:65002")
  190. if err != nil {
  191. fmt.Printf("Failed to listen: %v\n", err)
  192. }
  193. fmt.Printf("Listen SnakeSimulator localhost:65002\n")
  194. if err := grpcServer.Serve(lis); err != nil {
  195. fmt.Printf("Failed to serve: %v\n", err)
  196. }
  197. }()
  198. }
  199. func (s *SnakeSimulator) Run() {
  200. s.field.GenerateNextFood()
  201. for true {
  202. direction := rand.Int31()%4 + 1
  203. newHead := s.snake.NewHead(Direction(direction))
  204. if newHead.X == s.field.Food.X && newHead.Y == s.field.Food.Y {
  205. s.snake.Feed(newHead)
  206. s.field.GenerateNextFood()
  207. s.fieldUpdateQueue <- true
  208. } else if newHead.X >= s.field.Width || newHead.Y >= s.field.Height {
  209. fmt.Printf("Game over\n")
  210. break
  211. } else if selfCollisionIndex := s.snake.SelfCollision(newHead); selfCollisionIndex > 0 {
  212. if selfCollisionIndex == 1 {
  213. fmt.Printf("Step backward, skip\n")
  214. continue
  215. }
  216. fmt.Printf("Game over self collision\n")
  217. break
  218. } else {
  219. s.snake.Move(newHead)
  220. }
  221. s.snakeUpdateQueue <- true
  222. time.Sleep(50 * time.Millisecond)
  223. }
  224. }
  225. func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
  226. ctx := srv.Context()
  227. for {
  228. select {
  229. case <-ctx.Done():
  230. return ctx.Err()
  231. default:
  232. }
  233. srv.Send(s.field)
  234. <-s.fieldUpdateQueue
  235. }
  236. }
  237. func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
  238. ctx := srv.Context()
  239. for {
  240. select {
  241. case <-ctx.Done():
  242. return ctx.Err()
  243. default:
  244. }
  245. srv.Send(s.snake)
  246. <-s.snakeUpdateQueue
  247. }
  248. }
  249. func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
  250. ctx := srv.Context()
  251. for {
  252. select {
  253. case <-ctx.Done():
  254. return ctx.Err()
  255. default:
  256. }
  257. srv.Send(s.stats)
  258. <-s.statsUpdateQueue
  259. }
  260. }