snakesimulator.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package snakesimulator
  2. import (
  3. context "context"
  4. fmt "fmt"
  5. math "math"
  6. "math/rand"
  7. "net"
  8. "time"
  9. "gonum.org/v1/gonum/mat"
  10. genetic "../genetic"
  11. grpc "google.golang.org/grpc"
  12. )
  13. type SnakeSimulator struct {
  14. field *Field
  15. snake *Snake
  16. stats *Stats
  17. speed uint32
  18. fieldUpdateQueue chan bool
  19. snakeUpdateQueue chan bool
  20. statsUpdateQueue chan bool
  21. speedQueue chan uint32
  22. }
  23. func NewSnakeSimulator() (s *SnakeSimulator) {
  24. s = &SnakeSimulator{
  25. field: &Field{
  26. Food: &Point{},
  27. Width: 40,
  28. Height: 40,
  29. },
  30. snake: &Snake{
  31. Points: []*Point{
  32. &Point{X: 20, Y: 20},
  33. &Point{X: 21, Y: 20},
  34. &Point{X: 22, Y: 20},
  35. },
  36. },
  37. stats: &Stats{},
  38. fieldUpdateQueue: make(chan bool, 2),
  39. snakeUpdateQueue: make(chan bool, 2),
  40. statsUpdateQueue: make(chan bool, 2),
  41. speedQueue: make(chan uint32, 1),
  42. speed: 10,
  43. }
  44. return
  45. }
  46. func (s *SnakeSimulator) Verify(population *genetic.Population) (results []*genetic.IndividalResult) {
  47. s.stats.Generation++
  48. s.statsUpdateQueue <- true
  49. s.field.GenerateNextFood()
  50. results = make([]*genetic.IndividalResult, len(population.Networks))
  51. for index, inidividual := range population.Networks {
  52. s.stats.Individual = uint32(index)
  53. s.statsUpdateQueue <- true
  54. rand.Seed(time.Now().UnixNano())
  55. switch rand.Uint32() % 4 {
  56. case 1:
  57. s.snake = &Snake{
  58. Points: []*Point{
  59. &Point{X: 20, Y: 20},
  60. &Point{X: 21, Y: 20},
  61. &Point{X: 22, Y: 20},
  62. },
  63. }
  64. case 2:
  65. s.snake = &Snake{
  66. Points: []*Point{
  67. &Point{X: 22, Y: 20},
  68. &Point{X: 21, Y: 20},
  69. &Point{X: 20, Y: 20},
  70. },
  71. }
  72. case 3:
  73. s.snake = &Snake{
  74. Points: []*Point{
  75. &Point{X: 20, Y: 20},
  76. &Point{X: 20, Y: 21},
  77. &Point{X: 20, Y: 22},
  78. },
  79. }
  80. default:
  81. s.snake = &Snake{
  82. Points: []*Point{
  83. &Point{X: 20, Y: 22},
  84. &Point{X: 20, Y: 21},
  85. &Point{X: 20, Y: 20},
  86. },
  87. }
  88. }
  89. i := 0
  90. s.stats.Move = 0
  91. for i < 300 {
  92. if s.speed > 0 {
  93. s.statsUpdateQueue <- true
  94. }
  95. //Read speed from client
  96. select {
  97. case newSpeed := <-s.speedQueue:
  98. fmt.Printf("Apply new speed: %v\n", newSpeed)
  99. if newSpeed < 10 {
  100. if newSpeed > 0 {
  101. s.speed = newSpeed
  102. } else {
  103. s.speed = 0
  104. }
  105. }
  106. default:
  107. }
  108. i++
  109. if s.speed > 0 {
  110. s.snakeUpdateQueue <- true
  111. }
  112. direction, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.GetHeadState()))
  113. newHead := s.snake.NewHead(Direction(direction + 1))
  114. if selfCollisionIndex := s.snake.SelfCollision(newHead); selfCollisionIndex > 0 {
  115. if selfCollisionIndex == 1 {
  116. // switch Direction(direction + 1) {
  117. // case Direction_Up:
  118. // newHead = s.snake.NewHead(Direction_Down)
  119. // case Direction_Down:
  120. // newHead = s.snake.NewHead(Direction_Up)
  121. // case Direction_Left:
  122. // newHead = s.snake.NewHead(Direction_Right)
  123. // default:
  124. // newHead = s.snake.NewHead(Direction_Left)
  125. // }
  126. continue
  127. } else {
  128. fmt.Printf("Game over self collision\n")
  129. break
  130. }
  131. }
  132. if newHead.X == s.field.Food.X && newHead.Y == s.field.Food.Y {
  133. s.snake.Feed(newHead)
  134. s.field.GenerateNextFood()
  135. if s.speed > 0 {
  136. s.fieldUpdateQueue <- true
  137. }
  138. } else if newHead.X >= s.field.Width || newHead.Y >= s.field.Height {
  139. // fmt.Printf("Game over\n")
  140. // time.Sleep(1000 * time.Millisecond)
  141. break
  142. } else {
  143. s.snake.Move(newHead)
  144. }
  145. if s.speed > 0 {
  146. time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
  147. }
  148. s.stats.Move++
  149. }
  150. results[index] = &genetic.IndividalResult{
  151. Result: float64(len(s.snake.Points)-2) * float64(s.stats.Move),
  152. // Result: float64(s.stats.Move),
  153. Index: index,
  154. }
  155. }
  156. return
  157. }
  158. func (s *SnakeSimulator) GetHeadState() []float64 {
  159. headX := float64(s.snake.Points[0].X)
  160. headY := float64(s.snake.Points[0].Y)
  161. tailX := float64(s.snake.Points[len(s.snake.Points)-1].X)
  162. tailY := float64(s.snake.Points[len(s.snake.Points)-1].Y)
  163. // prevX := float64(s.snake.Points[1].X)
  164. // prevY := float64(s.snake.Points[1].Y)
  165. foodX := float64(s.field.Food.X)
  166. foodY := float64(s.field.Food.Y)
  167. width := float64(s.field.Width)
  168. height := float64(s.field.Height)
  169. diag := float64(width) * math.Sqrt2
  170. // tBack := float64(0.0)
  171. // bBack := float64(0.0)
  172. // lBack := float64(0.0)
  173. // rBack := float64(0.0)
  174. // if prevX == headX {
  175. // if prevY > headY {
  176. // bBack = 1.0
  177. // } else {
  178. // tBack = 1.0
  179. // }
  180. // }
  181. // if prevY == headY {
  182. // if prevX > headX {
  183. // lBack = 1.0
  184. // } else {
  185. // rBack = 1.0
  186. // }
  187. // }
  188. lWall := headX
  189. rWall := (width - headX)
  190. tWall := headY
  191. bWall := (height - headY)
  192. lFood := float64(width)
  193. rFood := float64(width)
  194. tFood := float64(height)
  195. bFood := float64(height)
  196. tlFood := float64(diag)
  197. trFood := float64(diag)
  198. blFood := float64(diag)
  199. brFood := float64(diag)
  200. tlWall := float64(0)
  201. trWall := float64(0)
  202. blWall := float64(0)
  203. brWall := float64(0)
  204. if headX == foodX {
  205. tFood = headY - foodY
  206. if tFood < 0 {
  207. tFood = height
  208. }
  209. bFood = foodY - headY
  210. if bFood < 0 {
  211. bFood = height
  212. }
  213. }
  214. if headY == foodY {
  215. rFood = foodX - headX
  216. if rFood < 0 {
  217. rFood = width
  218. }
  219. lFood = headX - foodX
  220. if lFood < 0 {
  221. lFood = width
  222. }
  223. }
  224. if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
  225. if foodX > headX {
  226. if foodY > headY {
  227. trFood = math.Abs(foodX-headX) * math.Sqrt2
  228. } else {
  229. brFood = math.Abs(foodX-headX) * math.Sqrt2
  230. }
  231. } else {
  232. if foodY > headY {
  233. tlFood = math.Abs(foodX-headX) * math.Sqrt2
  234. } else {
  235. blFood = math.Abs(foodX-headX) * math.Sqrt2
  236. }
  237. }
  238. }
  239. if lWall > tWall {
  240. tlWall = float64(tWall) * math.Sqrt2
  241. } else {
  242. tlWall = float64(lWall) * math.Sqrt2
  243. }
  244. if rWall > tWall {
  245. trWall = float64(tWall) * math.Sqrt2
  246. } else {
  247. trWall = float64(rWall) * math.Sqrt2
  248. }
  249. if lWall > bWall {
  250. blWall = float64(bWall) * math.Sqrt2
  251. } else {
  252. blWall = float64(lWall) * math.Sqrt2
  253. }
  254. if rWall > bWall {
  255. blWall = float64(bWall) * math.Sqrt2
  256. } else {
  257. brWall = float64(rWall) * math.Sqrt2
  258. }
  259. if lWall > tWall {
  260. tlWall = float64(tWall) * math.Sqrt2
  261. } else {
  262. tlWall = float64(lWall) * math.Sqrt2
  263. }
  264. if rWall > tWall {
  265. trWall = float64(tWall) * math.Sqrt2
  266. } else {
  267. trWall = float64(rWall) * math.Sqrt2
  268. }
  269. if lWall > bWall {
  270. blWall = float64(bWall) * math.Sqrt2
  271. } else {
  272. blWall = float64(lWall) * math.Sqrt2
  273. }
  274. if rWall > bWall {
  275. blWall = float64(bWall) * math.Sqrt2
  276. } else {
  277. brWall = float64(rWall) * math.Sqrt2
  278. }
  279. tTail := (headY - tailY)
  280. if tTail < 0 {
  281. tTail = height
  282. }
  283. bTail := (tailY - headY)
  284. if bTail < 0 {
  285. bTail = height
  286. }
  287. lTail := (headX - tailX)
  288. if lTail < 0 {
  289. tTail = width
  290. }
  291. rTail := (tailX - headX)
  292. if lTail < 0 {
  293. tTail = width
  294. }
  295. tlTail := float64(diag)
  296. trTail := float64(diag)
  297. blTail := float64(diag)
  298. brTail := float64(diag)
  299. if math.Abs(headY-tailY) == math.Abs(headX-tailX) {
  300. if tailY > headY {
  301. if tailX > headX {
  302. trTail = math.Abs(tailX-headX) * math.Sqrt2
  303. } else {
  304. tlTail = math.Abs(tailX-headX) * math.Sqrt2
  305. }
  306. } else {
  307. if tailX > headX {
  308. brTail = math.Abs(tailX-headX) * math.Sqrt2
  309. } else {
  310. blTail = math.Abs(tailX-headX) * math.Sqrt2
  311. }
  312. }
  313. }
  314. return []float64{
  315. lWall / width,
  316. rWall / width,
  317. tWall / height,
  318. bWall / height,
  319. (1.0 - lFood/width),
  320. (1.0 - rFood/width),
  321. (1.0 - tFood/height),
  322. (1.0 - bFood/height),
  323. tlWall / diag,
  324. trWall / diag,
  325. blWall / diag,
  326. brWall / diag,
  327. (1.0 - tlFood/diag),
  328. (1.0 - trFood/diag),
  329. (1.0 - blFood/diag),
  330. (1.0 - brFood/diag),
  331. tTail / height,
  332. bTail / height,
  333. lTail / width,
  334. rTail / width,
  335. tlTail / diag,
  336. trTail / diag,
  337. blTail / diag,
  338. brTail / diag,
  339. }
  340. }
  341. func (s *SnakeSimulator) StartServer() {
  342. go func() {
  343. grpcServer := grpc.NewServer()
  344. RegisterSnakeSimulatorServer(grpcServer, s)
  345. lis, err := net.Listen("tcp", "localhost:65002")
  346. if err != nil {
  347. fmt.Printf("Failed to listen: %v\n", err)
  348. }
  349. fmt.Printf("Listen SnakeSimulator localhost:65002\n")
  350. if err := grpcServer.Serve(lis); err != nil {
  351. fmt.Printf("Failed to serve: %v\n", err)
  352. }
  353. }()
  354. }
  355. func (s *SnakeSimulator) Run() {
  356. s.field.GenerateNextFood()
  357. for true {
  358. direction := rand.Int31()%4 + 1
  359. newHead := s.snake.NewHead(Direction(direction))
  360. if newHead.X == s.field.Food.X && newHead.Y == s.field.Food.Y {
  361. s.snake.Feed(newHead)
  362. s.field.GenerateNextFood()
  363. s.fieldUpdateQueue <- true
  364. } else if newHead.X > s.field.Width || newHead.Y > s.field.Height {
  365. fmt.Printf("Game over\n")
  366. break
  367. } else if selfCollisionIndex := s.snake.SelfCollision(newHead); selfCollisionIndex > 0 {
  368. if selfCollisionIndex == 1 {
  369. fmt.Printf("Step backward, skip\n")
  370. continue
  371. }
  372. fmt.Printf("Game over self collision\n")
  373. break
  374. } else {
  375. s.snake.Move(newHead)
  376. }
  377. s.snakeUpdateQueue <- true
  378. time.Sleep(50 * time.Millisecond)
  379. }
  380. }
  381. func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
  382. ctx := srv.Context()
  383. for {
  384. select {
  385. case <-ctx.Done():
  386. return ctx.Err()
  387. default:
  388. }
  389. srv.Send(s.field)
  390. <-s.fieldUpdateQueue
  391. }
  392. }
  393. func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
  394. ctx := srv.Context()
  395. for {
  396. select {
  397. case <-ctx.Done():
  398. return ctx.Err()
  399. default:
  400. }
  401. srv.Send(s.snake)
  402. <-s.snakeUpdateQueue
  403. }
  404. }
  405. func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
  406. ctx := srv.Context()
  407. for {
  408. select {
  409. case <-ctx.Done():
  410. return ctx.Err()
  411. default:
  412. }
  413. srv.Send(s.stats)
  414. <-s.statsUpdateQueue
  415. }
  416. }
  417. func (s *SnakeSimulator) SetSpeed(ctx context.Context, speed *Speed) (*None, error) {
  418. s.speedQueue <- speed.Speed
  419. return &None{}, nil
  420. }