snakesimulator.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. package snakesimulator
  2. import (
  3. context "context"
  4. fmt "fmt"
  5. math "math"
  6. "math/rand"
  7. "net"
  8. "sort"
  9. "time"
  10. "gonum.org/v1/gonum/mat"
  11. genetic "../genetic"
  12. neuralnetwork "../neuralnetwork"
  13. grpc "google.golang.org/grpc"
  14. )
  15. type SnakeSimulator struct {
  16. field *Field
  17. snake *Snake
  18. maxVerificationSteps int
  19. stats *Stats
  20. //GUI interface part
  21. speed uint32
  22. fieldUpdateQueue chan bool
  23. snakeUpdateQueue chan bool
  24. statsUpdateQueue chan bool
  25. speedQueue chan uint32
  26. }
  27. // Initializes new snake population with maximum number of verification steps
  28. func NewSnakeSimulator(maxVerificationSteps int) (s *SnakeSimulator) {
  29. s = &SnakeSimulator{
  30. field: &Field{
  31. Food: &Point{},
  32. Width: 40,
  33. Height: 40,
  34. },
  35. snake: &Snake{
  36. Points: []*Point{
  37. &Point{X: 20, Y: 20},
  38. &Point{X: 21, Y: 20},
  39. &Point{X: 22, Y: 20},
  40. },
  41. },
  42. stats: &Stats{},
  43. maxVerificationSteps: maxVerificationSteps,
  44. fieldUpdateQueue: make(chan bool, 2),
  45. snakeUpdateQueue: make(chan bool, 2),
  46. statsUpdateQueue: make(chan bool, 2),
  47. speedQueue: make(chan uint32, 1),
  48. speed: 10,
  49. }
  50. return
  51. }
  52. // Population test method
  53. // Verifies population and returns unsorted finteses for each individual
  54. func (s *SnakeSimulator) Verify(population *genetic.Population) (fitnesses []*genetic.IndividalFitness) {
  55. s.stats.Generation++
  56. s.statsUpdateQueue <- true
  57. s.field.GenerateNextFood()
  58. s.fieldUpdateQueue <- true
  59. fitnesses = make([]*genetic.IndividalFitness, len(population.Networks))
  60. for index, inidividual := range population.Networks {
  61. s.stats.Individual = uint32(index)
  62. s.statsUpdateQueue <- true
  63. s.runSnake(inidividual, false)
  64. fitnesses[index] = &genetic.IndividalFitness{
  65. Fitness: float64(s.stats.Move),
  66. Index: index,
  67. }
  68. }
  69. //This is duplication of crossbreedPopulation functionality to display best snake
  70. sort.Slice(fitnesses, func(i, j int) bool {
  71. return fitnesses[i].Fitness > fitnesses[j].Fitness //Descent order best will be on top, worst in the bottom
  72. })
  73. //Best snake showtime!
  74. prevSpeed := s.speed
  75. s.speed = 2
  76. s.runSnake(population.Networks[fitnesses[0].Index], false)
  77. s.speed = prevSpeed
  78. return
  79. }
  80. func (s *SnakeSimulator) runSnake(inidividual *neuralnetwork.NeuralNetwork, randomStart bool) {
  81. if randomStart {
  82. rand.Seed(time.Now().UnixNano())
  83. s.snake = NewSnake(Direction(rand.Uint32()%4), *s.field)
  84. } else {
  85. s.snake = NewSnake(Direction_Left, *s.field)
  86. }
  87. s.stats.Move = 0
  88. for i := 0; i < s.maxVerificationSteps; i++ {
  89. //Read speed from client and sleep in case if user selected slow preview
  90. select {
  91. case newSpeed := <-s.speedQueue:
  92. fmt.Printf("Apply new speed: %v\n", newSpeed)
  93. if newSpeed <= 10 && newSpeed >= 0 {
  94. s.speed = newSpeed
  95. } else if newSpeed < 0 {
  96. s.speed = 0
  97. }
  98. default:
  99. }
  100. if s.speed > 0 {
  101. time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
  102. s.statsUpdateQueue <- true
  103. s.snakeUpdateQueue <- true
  104. }
  105. predictIndex, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.getHeadState()))
  106. direction := Direction(predictIndex + 1)
  107. newHead := s.snake.NewHead(direction)
  108. if s.snake.selfCollision(newHead, direction) {
  109. fmt.Printf("Game over self collision\n")
  110. break
  111. } else if wallCollision(newHead, *s.field) {
  112. break
  113. } else if foodCollision(newHead, s.field.Food) {
  114. i = 0
  115. s.snake.Feed(newHead)
  116. s.field.GenerateNextFood()
  117. s.fieldUpdateQueue <- true
  118. } else {
  119. s.snake.Move(newHead)
  120. }
  121. s.stats.Move++
  122. }
  123. }
  124. // Produces input activations for neural network
  125. func (s *SnakeSimulator) getHeadState() []float64 {
  126. // Snake state
  127. headX := float64(s.snake.Points[0].X)
  128. headY := float64(s.snake.Points[0].Y)
  129. tailX := float64(s.snake.Points[len(s.snake.Points)-1].X)
  130. tailY := float64(s.snake.Points[len(s.snake.Points)-1].Y)
  131. // Field state
  132. foodX := float64(s.field.Food.X)
  133. foodY := float64(s.field.Food.Y)
  134. width := float64(s.field.Width)
  135. height := float64(s.field.Height)
  136. diag := float64(width) * math.Sqrt2 //We assume that field is always square
  137. // Output activations
  138. // Distance to walls in 4 directions
  139. lWall := headX
  140. rWall := (width - headX)
  141. tWall := headY
  142. bWall := (height - headY)
  143. // Distance to walls in 4 diagonal directions, by default is completely inactive
  144. tlWall := float64(0)
  145. trWall := float64(0)
  146. blWall := float64(0)
  147. brWall := float64(0)
  148. // Distance to food in 4 directions
  149. // By default is size of field that means that there is no activation at all
  150. lFood := float64(width)
  151. rFood := float64(width)
  152. tFood := float64(height)
  153. bFood := float64(height)
  154. // Distance to food in 4 diagonal directions
  155. // By default is size of field diagonal that means that there is no activation
  156. // at all
  157. tlFood := float64(diag)
  158. trFood := float64(diag)
  159. blFood := float64(diag)
  160. brFood := float64(diag)
  161. // Distance to tail in 4 directions
  162. tTail := float64(0)
  163. bTail := float64(0)
  164. lTail := float64(0)
  165. rTail := float64(0)
  166. // Distance to tail in 4 diagonal directions
  167. tlTail := float64(0)
  168. trTail := float64(0)
  169. blTail := float64(0)
  170. brTail := float64(0)
  171. // Diagonal distance to each wall
  172. if lWall > tWall {
  173. tlWall = float64(tWall) * math.Sqrt2
  174. } else {
  175. tlWall = float64(lWall) * math.Sqrt2
  176. }
  177. if rWall > tWall {
  178. trWall = float64(tWall) * math.Sqrt2
  179. } else {
  180. trWall = float64(rWall) * math.Sqrt2
  181. }
  182. if lWall > bWall {
  183. blWall = float64(bWall) * math.Sqrt2
  184. } else {
  185. blWall = float64(lWall) * math.Sqrt2
  186. }
  187. if rWall > bWall {
  188. blWall = float64(bWall) * math.Sqrt2
  189. } else {
  190. brWall = float64(rWall) * math.Sqrt2
  191. }
  192. // Check if food is on same vertical line with head and
  193. // choose vertical direction for activation
  194. if headX == foodX {
  195. if headY-foodY < 0 {
  196. bFood = 0
  197. } else {
  198. tFood = 0
  199. }
  200. }
  201. // Check if food is on same horizontal line with head and
  202. // choose horizontal direction for activation
  203. if headY == foodY {
  204. if foodX-headX < 0 {
  205. lFood = 0
  206. } else {
  207. rFood = 0
  208. }
  209. }
  210. //Check if food is on diagonal any of 4 ways
  211. if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
  212. //Choose diagonal direction to food
  213. if foodX > headX {
  214. if foodY > headY {
  215. trFood = 0
  216. } else {
  217. brFood = 0
  218. }
  219. } else {
  220. if foodY > headY {
  221. tlFood = 0
  222. } else {
  223. blFood = 0
  224. }
  225. }
  226. }
  227. // Check if tail is on same vertical line with head and
  228. // choose vertical direction for activation
  229. if headX == tailX {
  230. if headY-tailY < 0 {
  231. bTail = height
  232. } else {
  233. tTail = height
  234. }
  235. }
  236. // Check if tail is on same horizontal line with head and
  237. // choose horizontal direction for activation
  238. if headY == tailY {
  239. if headX-tailX < 0 {
  240. rTail = width
  241. } else {
  242. lTail = width
  243. }
  244. }
  245. //Check if tail is on diagonal any of 4 ways
  246. if math.Abs(headY-tailY) == math.Abs(headX-tailX) {
  247. //Choose diagonal direction to tail
  248. if tailY > headY {
  249. if tailX > headX {
  250. trTail = diag
  251. } else {
  252. tlTail = diag
  253. }
  254. } else {
  255. if tailX > headX {
  256. brTail = diag
  257. } else {
  258. blTail = diag
  259. }
  260. }
  261. }
  262. return []float64{
  263. lWall / width,
  264. rWall / width,
  265. tWall / height,
  266. bWall / height,
  267. tlWall / diag,
  268. trWall / diag,
  269. blWall / diag,
  270. brWall / diag,
  271. (1.0 - lFood/width),
  272. (1.0 - rFood/width),
  273. (1.0 - tFood/height),
  274. (1.0 - bFood/height),
  275. (1.0 - tlFood/diag),
  276. (1.0 - trFood/diag),
  277. (1.0 - blFood/diag),
  278. (1.0 - brFood/diag),
  279. tTail / height,
  280. bTail / height,
  281. lTail / width,
  282. rTail / width,
  283. tlTail / diag,
  284. trTail / diag,
  285. blTail / diag,
  286. brTail / diag,
  287. }
  288. }
  289. // Server part
  290. // Runs gRPC server for GUI
  291. func (s *SnakeSimulator) StartServer() {
  292. go func() {
  293. grpcServer := grpc.NewServer()
  294. RegisterSnakeSimulatorServer(grpcServer, s)
  295. lis, err := net.Listen("tcp", "localhost:65002")
  296. if err != nil {
  297. fmt.Printf("Failed to listen: %v\n", err)
  298. }
  299. fmt.Printf("Listen SnakeSimulator localhost:65002\n")
  300. if err := grpcServer.Serve(lis); err != nil {
  301. fmt.Printf("Failed to serve: %v\n", err)
  302. }
  303. }()
  304. }
  305. // Steaming of Field updates
  306. func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
  307. ctx := srv.Context()
  308. for {
  309. select {
  310. case <-ctx.Done():
  311. return ctx.Err()
  312. default:
  313. }
  314. srv.Send(s.field)
  315. <-s.fieldUpdateQueue
  316. }
  317. }
  318. // Steaming of Snake position and length updates
  319. func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
  320. ctx := srv.Context()
  321. for {
  322. select {
  323. case <-ctx.Done():
  324. return ctx.Err()
  325. default:
  326. }
  327. srv.Send(s.snake)
  328. <-s.snakeUpdateQueue
  329. }
  330. }
  331. // Steaming of snake simulator statistic
  332. func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
  333. ctx := srv.Context()
  334. for {
  335. select {
  336. case <-ctx.Done():
  337. return ctx.Err()
  338. default:
  339. }
  340. srv.Send(s.stats)
  341. <-s.statsUpdateQueue
  342. }
  343. }
  344. // Setup new speed requested from gRPC GUI client
  345. func (s *SnakeSimulator) SetSpeed(ctx context.Context, speed *Speed) (*None, error) {
  346. s.speedQueue <- speed.Speed
  347. return &None{}, nil
  348. }