snakesimulator.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. package snakesimulator
  2. import (
  3. context "context"
  4. fmt "fmt"
  5. math "math"
  6. "math/rand"
  7. "net"
  8. "sort"
  9. "time"
  10. "gonum.org/v1/gonum/mat"
  11. genetic "../genetic"
  12. neuralnetwork "../neuralnetwork"
  13. grpc "google.golang.org/grpc"
  14. )
  15. type SnakeSimulator struct {
  16. field *Field
  17. snake *Snake
  18. maxVerificationSteps int
  19. stats *Stats
  20. //GUI interface part
  21. speed uint32
  22. fieldUpdateQueue chan bool
  23. snakeUpdateQueue chan bool
  24. statsUpdateQueue chan bool
  25. speedQueue chan uint32
  26. }
  27. // Initializes new snake population with maximum number of verification steps
  28. func NewSnakeSimulator(maxVerificationSteps int) (s *SnakeSimulator) {
  29. s = &SnakeSimulator{
  30. field: &Field{
  31. Food: &Point{},
  32. Width: 40,
  33. Height: 40,
  34. },
  35. snake: &Snake{
  36. Points: []*Point{
  37. &Point{X: 20, Y: 20},
  38. &Point{X: 21, Y: 20},
  39. &Point{X: 22, Y: 20},
  40. },
  41. },
  42. stats: &Stats{},
  43. maxVerificationSteps: maxVerificationSteps,
  44. fieldUpdateQueue: make(chan bool, 2),
  45. snakeUpdateQueue: make(chan bool, 2),
  46. statsUpdateQueue: make(chan bool, 2),
  47. speedQueue: make(chan uint32, 1),
  48. speed: 10,
  49. }
  50. return
  51. }
  52. // Population test method
  53. // Verifies population and returns unsorted finteses for each individual
  54. func (s *SnakeSimulator) Verify(population *genetic.Population) (fitnesses []*genetic.IndividalFitness) {
  55. s.stats.Generation++
  56. s.statsUpdateQueue <- true
  57. s.field.GenerateNextFood()
  58. s.fieldUpdateQueue <- true
  59. fitnesses = make([]*genetic.IndividalFitness, len(population.Networks))
  60. for index, inidividual := range population.Networks {
  61. s.stats.Individual = uint32(index)
  62. s.statsUpdateQueue <- true
  63. s.runSnake(inidividual, false)
  64. fitnesses[index] = &genetic.IndividalFitness{
  65. // Fitness: float64(s.stats.Move), //Uncomment this to decrese food impact to individual selection
  66. Fitness: float64(s.stats.Move) * float64(len(s.snake.Points)-2),
  67. Index: index,
  68. }
  69. }
  70. //This is duplication of crossbreedPopulation functionality to display best snake
  71. sort.Slice(fitnesses, func(i, j int) bool {
  72. return fitnesses[i].Fitness > fitnesses[j].Fitness //Descent order best will be on top, worst in the bottom
  73. })
  74. //Best snake showtime!
  75. prevSpeed := s.speed
  76. s.speed = 5
  77. s.runSnake(population.Networks[fitnesses[0].Index], false)
  78. s.speed = prevSpeed
  79. return
  80. }
  81. func (s *SnakeSimulator) runSnake(inidividual *neuralnetwork.NeuralNetwork, randomStart bool) {
  82. if randomStart {
  83. rand.Seed(time.Now().UnixNano())
  84. s.snake = NewSnake(Direction(rand.Uint32()%4), *s.field)
  85. } else {
  86. s.snake = NewSnake(Direction_Left, *s.field)
  87. }
  88. s.stats.Move = 0
  89. for i := 0; i < s.maxVerificationSteps; i++ {
  90. //Read speed from client and sleep in case if user selected slow preview
  91. select {
  92. case newSpeed := <-s.speedQueue:
  93. fmt.Printf("Apply new speed: %v\n", newSpeed)
  94. if newSpeed <= 10 && newSpeed >= 0 {
  95. s.speed = newSpeed
  96. } else if newSpeed < 0 {
  97. s.speed = 0
  98. }
  99. default:
  100. }
  101. if s.speed > 0 {
  102. time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
  103. s.statsUpdateQueue <- true
  104. s.snakeUpdateQueue <- true
  105. }
  106. predictIndex, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.getHeadState()))
  107. direction := Direction(predictIndex + 1)
  108. newHead := s.snake.NewHead(direction)
  109. if s.snake.selfCollision(newHead, direction) {
  110. fmt.Printf("Game over self collision\n")
  111. break
  112. } else if wallCollision(newHead, *s.field) {
  113. break
  114. } else if foodCollision(newHead, s.field.Food) {
  115. i = 0
  116. s.snake.Feed(newHead)
  117. s.field.GenerateNextFood()
  118. s.fieldUpdateQueue <- true
  119. } else {
  120. s.snake.Move(newHead)
  121. }
  122. s.stats.Move++
  123. }
  124. }
  125. // Produces input activations for neural network
  126. func (s *SnakeSimulator) getHeadState() []float64 {
  127. // Snake state
  128. headX := float64(s.snake.Points[0].X)
  129. headY := float64(s.snake.Points[0].Y)
  130. tailX := float64(s.snake.Points[len(s.snake.Points)-1].X)
  131. tailY := float64(s.snake.Points[len(s.snake.Points)-1].Y)
  132. // Field state
  133. foodX := float64(s.field.Food.X)
  134. foodY := float64(s.field.Food.Y)
  135. width := float64(s.field.Width)
  136. height := float64(s.field.Height)
  137. diag := float64(width) * math.Sqrt2 //We assume that field is always square
  138. // Output activations
  139. // Distance to walls in 4 directions
  140. lWall := headX
  141. rWall := (width - headX)
  142. tWall := headY
  143. bWall := (height - headY)
  144. // Distance to walls in 4 diagonal directions, by default is completely inactive
  145. tlWall := float64(0)
  146. trWall := float64(0)
  147. blWall := float64(0)
  148. brWall := float64(0)
  149. // Distance to food in 4 directions
  150. // By default is size of field that means that there is no activation at all
  151. lFood := float64(width)
  152. rFood := float64(width)
  153. tFood := float64(height)
  154. bFood := float64(height)
  155. // Distance to food in 4 diagonal directions
  156. // By default is size of field diagonal that means that there is no activation
  157. // at all
  158. tlFood := float64(diag)
  159. trFood := float64(diag)
  160. blFood := float64(diag)
  161. brFood := float64(diag)
  162. // Distance to tail in 4 directions
  163. tTail := float64(0)
  164. bTail := float64(0)
  165. lTail := float64(0)
  166. rTail := float64(0)
  167. // Distance to tail in 4 diagonal directions
  168. tlTail := float64(0)
  169. trTail := float64(0)
  170. blTail := float64(0)
  171. brTail := float64(0)
  172. // Diagonal distance to each wall
  173. if lWall > tWall {
  174. tlWall = float64(tWall) * math.Sqrt2
  175. } else {
  176. tlWall = float64(lWall) * math.Sqrt2
  177. }
  178. if rWall > tWall {
  179. trWall = float64(tWall) * math.Sqrt2
  180. } else {
  181. trWall = float64(rWall) * math.Sqrt2
  182. }
  183. if lWall > bWall {
  184. blWall = float64(bWall) * math.Sqrt2
  185. } else {
  186. blWall = float64(lWall) * math.Sqrt2
  187. }
  188. if rWall > bWall {
  189. blWall = float64(bWall) * math.Sqrt2
  190. } else {
  191. brWall = float64(rWall) * math.Sqrt2
  192. }
  193. // Check if food is on same vertical line with head and
  194. // choose vertical direction for activation
  195. if headX == foodX {
  196. if headY-foodY > 0 {
  197. tFood = 0
  198. } else {
  199. bFood = 0
  200. }
  201. }
  202. // Check if food is on same horizontal line with head and
  203. // choose horizontal direction for activation
  204. if headY == foodY {
  205. if headX-foodX > 0 {
  206. lFood = 0
  207. } else {
  208. rFood = 0
  209. }
  210. }
  211. //Check if food is on diagonal any of 4 ways
  212. if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
  213. //Choose diagonal direction to food
  214. if foodX > headX {
  215. if foodY > headY {
  216. trFood = 0
  217. } else {
  218. brFood = 0
  219. }
  220. } else {
  221. if foodY > headY {
  222. tlFood = 0
  223. } else {
  224. blFood = 0
  225. }
  226. }
  227. }
  228. // Check if tail is on same vertical line with head and
  229. // choose vertical direction for activation
  230. if headX == tailX {
  231. if headY-tailY > 0 {
  232. tTail = headY - tailY
  233. } else {
  234. bTail = headY - tailY
  235. }
  236. }
  237. // Check if tail is on same horizontal line with head and
  238. // choose horizontal direction for activation
  239. if headY == tailY {
  240. if headX-tailX > 0 {
  241. rTail = headX - tailX
  242. } else {
  243. lTail = headX - tailX
  244. }
  245. }
  246. //Check if tail is on diagonal any of 4 ways
  247. if math.Abs(headY-tailY) == math.Abs(headX-tailX) {
  248. //Choose diagonal direction to tail
  249. if tailY > headY {
  250. if tailX > headX {
  251. trTail = diag
  252. } else {
  253. tlTail = diag
  254. }
  255. } else {
  256. if tailX > headX {
  257. brTail = diag
  258. } else {
  259. blTail = diag
  260. }
  261. }
  262. }
  263. return []float64{
  264. lWall / width,
  265. rWall / width,
  266. tWall / height,
  267. bWall / height,
  268. tlWall / diag,
  269. trWall / diag,
  270. blWall / diag,
  271. brWall / diag,
  272. (1.0 - lFood/width),
  273. (1.0 - rFood/width),
  274. (1.0 - tFood/height),
  275. (1.0 - bFood/height),
  276. (1.0 - tlFood/diag),
  277. (1.0 - trFood/diag),
  278. (1.0 - blFood/diag),
  279. (1.0 - brFood/diag),
  280. tTail / height,
  281. bTail / height,
  282. lTail / width,
  283. rTail / width,
  284. tlTail / diag,
  285. trTail / diag,
  286. blTail / diag,
  287. brTail / diag,
  288. }
  289. }
  290. // Server part
  291. // Runs gRPC server for GUI
  292. func (s *SnakeSimulator) StartServer() {
  293. go func() {
  294. grpcServer := grpc.NewServer()
  295. RegisterSnakeSimulatorServer(grpcServer, s)
  296. lis, err := net.Listen("tcp", "localhost:65002")
  297. if err != nil {
  298. fmt.Printf("Failed to listen: %v\n", err)
  299. }
  300. fmt.Printf("Listen SnakeSimulator localhost:65002\n")
  301. if err := grpcServer.Serve(lis); err != nil {
  302. fmt.Printf("Failed to serve: %v\n", err)
  303. }
  304. }()
  305. }
  306. // Steaming of Field updates
  307. func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
  308. ctx := srv.Context()
  309. for {
  310. select {
  311. case <-ctx.Done():
  312. return ctx.Err()
  313. default:
  314. }
  315. srv.Send(s.field)
  316. <-s.fieldUpdateQueue
  317. }
  318. }
  319. // Steaming of Snake position and length updates
  320. func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
  321. ctx := srv.Context()
  322. for {
  323. select {
  324. case <-ctx.Done():
  325. return ctx.Err()
  326. default:
  327. }
  328. srv.Send(s.snake)
  329. <-s.snakeUpdateQueue
  330. }
  331. }
  332. // Steaming of snake simulator statistic
  333. func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
  334. ctx := srv.Context()
  335. for {
  336. select {
  337. case <-ctx.Done():
  338. return ctx.Err()
  339. default:
  340. }
  341. srv.Send(s.stats)
  342. <-s.statsUpdateQueue
  343. }
  344. }
  345. // Setup new speed requested from gRPC GUI client
  346. func (s *SnakeSimulator) SetSpeed(ctx context.Context, speed *Speed) (*None, error) {
  347. s.speedQueue <- speed.Speed
  348. return &None{}, nil
  349. }