snakesimulator.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. package snakesimulator
  2. import (
  3. context "context"
  4. fmt "fmt"
  5. math "math"
  6. "math/rand"
  7. "net"
  8. "sort"
  9. "time"
  10. "gonum.org/v1/gonum/mat"
  11. genetic "../genetic"
  12. neuralnetwork "../neuralnetwork"
  13. grpc "google.golang.org/grpc"
  14. )
  15. type SnakeSimulator struct {
  16. field *Field
  17. snake *Snake
  18. maxVerificationSteps int
  19. stats *Stats
  20. //GUI interface part
  21. speed uint32
  22. fieldUpdateQueue chan bool
  23. snakeUpdateQueue chan bool
  24. statsUpdateQueue chan bool
  25. speedQueue chan uint32
  26. }
  27. // Initializes new snake population with maximum number of verification steps
  28. func NewSnakeSimulator(maxVerificationSteps int) (s *SnakeSimulator) {
  29. s = &SnakeSimulator{
  30. field: &Field{
  31. Food: &Point{},
  32. Width: 40,
  33. Height: 40,
  34. },
  35. snake: &Snake{
  36. Points: []*Point{
  37. &Point{X: 20, Y: 20},
  38. &Point{X: 21, Y: 20},
  39. &Point{X: 22, Y: 20},
  40. },
  41. },
  42. stats: &Stats{},
  43. maxVerificationSteps: maxVerificationSteps,
  44. fieldUpdateQueue: make(chan bool, 2),
  45. snakeUpdateQueue: make(chan bool, 2),
  46. statsUpdateQueue: make(chan bool, 2),
  47. speedQueue: make(chan uint32, 1),
  48. speed: 10,
  49. }
  50. return
  51. }
  52. // Population test method
  53. // Verifies population and returns unsorted finteses for each individual
  54. func (s *SnakeSimulator) Verify(population *genetic.Population) (fitnesses []*genetic.IndividalFitness) {
  55. s.stats.Generation++
  56. s.statsUpdateQueue <- true
  57. s.field.GenerateNextFood()
  58. if s.speed > 0 {
  59. s.fieldUpdateQueue <- true
  60. }
  61. fitnesses = make([]*genetic.IndividalFitness, len(population.Networks))
  62. for index, inidividual := range population.Networks {
  63. s.stats.Individual = uint32(index)
  64. s.statsUpdateQueue <- true
  65. s.runSnake(inidividual, false)
  66. fitnesses[index] = &genetic.IndividalFitness{
  67. // Fitness: float64(s.stats.Move), //Uncomment this to decrese food impact to individual selection
  68. Fitness: float64(s.stats.Move) * float64(len(s.snake.Points)-2),
  69. Index: index,
  70. }
  71. }
  72. //This is duplication of crossbreedPopulation functionality to display best snake
  73. sort.Slice(fitnesses, func(i, j int) bool {
  74. return fitnesses[i].Fitness > fitnesses[j].Fitness //Descent order best will be on top, worst in the bottom
  75. })
  76. //Best snake showtime!
  77. s.field.GenerateNextFood()
  78. s.fieldUpdateQueue <- true
  79. prevSpeed := s.speed
  80. s.speed = 5
  81. s.runSnake(population.Networks[fitnesses[0].Index], false)
  82. s.speed = prevSpeed
  83. return
  84. }
  85. func (s *SnakeSimulator) runSnake(inidividual *neuralnetwork.NeuralNetwork, randomStart bool) {
  86. if randomStart {
  87. rand.Seed(time.Now().UnixNano())
  88. s.snake = NewSnake(Direction(rand.Uint32()%4), *s.field)
  89. } else {
  90. s.snake = NewSnake(Direction_Left, *s.field)
  91. }
  92. s.stats.Move = 0
  93. for i := 0; i < s.maxVerificationSteps; i++ {
  94. //Read speed from client and sleep in case if user selected slow preview
  95. select {
  96. case newSpeed := <-s.speedQueue:
  97. fmt.Printf("Apply new speed: %v\n", newSpeed)
  98. if newSpeed <= 10 && newSpeed >= 0 {
  99. s.speed = newSpeed
  100. } else if newSpeed < 0 {
  101. s.speed = 0
  102. }
  103. default:
  104. }
  105. if s.speed > 0 {
  106. time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
  107. s.statsUpdateQueue <- true
  108. s.snakeUpdateQueue <- true
  109. }
  110. predictIndex, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.getHeadState()))
  111. direction := Direction(predictIndex + 1)
  112. newHead := s.snake.NewHead(direction)
  113. if s.snake.selfCollision(newHead, direction) {
  114. fmt.Printf("Game over self collision\n")
  115. break
  116. } else if wallCollision(newHead, *s.field) {
  117. break
  118. } else if foodCollision(newHead, s.field.Food) {
  119. i = 0
  120. s.snake.Feed(newHead)
  121. s.field.GenerateNextFood()
  122. if s.speed > 0 {
  123. s.fieldUpdateQueue <- true
  124. }
  125. } else {
  126. s.snake.Move(newHead)
  127. }
  128. s.stats.Move++
  129. }
  130. }
  131. // Produces input activations for neural network
  132. func (s *SnakeSimulator) getHeadState() []float64 {
  133. // Snake state
  134. headX := float64(s.snake.Points[0].X)
  135. headY := float64(s.snake.Points[0].Y)
  136. tailX := float64(s.snake.Points[len(s.snake.Points)-1].X)
  137. tailY := float64(s.snake.Points[len(s.snake.Points)-1].Y)
  138. // Field state
  139. foodX := float64(s.field.Food.X)
  140. foodY := float64(s.field.Food.Y)
  141. width := float64(s.field.Width)
  142. height := float64(s.field.Height)
  143. diag := float64(width) * math.Sqrt2 //We assume that field is always square
  144. // Output activations
  145. // Distance to walls in 4 directions
  146. lWall := headX
  147. rWall := (width - headX)
  148. tWall := headY
  149. bWall := (height - headY)
  150. // Distance to walls in 4 diagonal directions, by default is completely inactive
  151. tlWall := float64(0)
  152. trWall := float64(0)
  153. blWall := float64(0)
  154. brWall := float64(0)
  155. // Distance to food in 4 directions
  156. // By default is size of field that means that there is no activation at all
  157. lFood := float64(width)
  158. rFood := float64(width)
  159. tFood := float64(height)
  160. bFood := float64(height)
  161. // Distance to food in 4 diagonal directions
  162. // By default is size of field diagonal that means that there is no activation
  163. // at all
  164. tlFood := float64(diag)
  165. trFood := float64(diag)
  166. blFood := float64(diag)
  167. brFood := float64(diag)
  168. // Distance to tail in 4 directions
  169. tTail := float64(0)
  170. bTail := float64(0)
  171. lTail := float64(0)
  172. rTail := float64(0)
  173. // Distance to tail in 4 diagonal directions
  174. tlTail := float64(0)
  175. trTail := float64(0)
  176. blTail := float64(0)
  177. brTail := float64(0)
  178. // Diagonal distance to each wall
  179. if lWall > tWall {
  180. tlWall = float64(tWall) * math.Sqrt2
  181. } else {
  182. tlWall = float64(lWall) * math.Sqrt2
  183. }
  184. if rWall > tWall {
  185. trWall = float64(tWall) * math.Sqrt2
  186. } else {
  187. trWall = float64(rWall) * math.Sqrt2
  188. }
  189. if lWall > bWall {
  190. blWall = float64(bWall) * math.Sqrt2
  191. } else {
  192. blWall = float64(lWall) * math.Sqrt2
  193. }
  194. if rWall > bWall {
  195. blWall = float64(bWall) * math.Sqrt2
  196. } else {
  197. brWall = float64(rWall) * math.Sqrt2
  198. }
  199. // Check if food is on same vertical line with head and
  200. // choose vertical direction for activation
  201. if headX == foodX {
  202. if headY-foodY > 0 {
  203. tFood = 0
  204. } else {
  205. bFood = 0
  206. }
  207. }
  208. // Check if food is on same horizontal line with head and
  209. // choose horizontal direction for activation
  210. if headY == foodY {
  211. if headX-foodX > 0 {
  212. lFood = 0
  213. } else {
  214. rFood = 0
  215. }
  216. }
  217. //Check if food is on diagonal any of 4 ways
  218. if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
  219. //Choose diagonal direction to food
  220. if foodX > headX {
  221. if foodY > headY {
  222. trFood = 0
  223. } else {
  224. brFood = 0
  225. }
  226. } else {
  227. if foodY > headY {
  228. tlFood = 0
  229. } else {
  230. blFood = 0
  231. }
  232. }
  233. }
  234. // Check if tail is on same vertical line with head and
  235. // choose vertical direction for activation
  236. if headX == tailX {
  237. if headY-tailY > 0 {
  238. tTail = headY - tailY
  239. } else {
  240. bTail = headY - tailY
  241. }
  242. }
  243. // Check if tail is on same horizontal line with head and
  244. // choose horizontal direction for activation
  245. if headY == tailY {
  246. if headX-tailX > 0 {
  247. rTail = headX - tailX
  248. } else {
  249. lTail = headX - tailX
  250. }
  251. }
  252. //Check if tail is on diagonal any of 4 ways
  253. if math.Abs(headY-tailY) == math.Abs(headX-tailX) {
  254. //Choose diagonal direction to tail
  255. if tailY > headY {
  256. if tailX > headX {
  257. trTail = diag
  258. } else {
  259. tlTail = diag
  260. }
  261. } else {
  262. if tailX > headX {
  263. brTail = diag
  264. } else {
  265. blTail = diag
  266. }
  267. }
  268. }
  269. return []float64{
  270. lWall / width,
  271. rWall / width,
  272. tWall / height,
  273. bWall / height,
  274. tlWall / diag,
  275. trWall / diag,
  276. blWall / diag,
  277. brWall / diag,
  278. (1.0 - lFood/width),
  279. (1.0 - rFood/width),
  280. (1.0 - tFood/height),
  281. (1.0 - bFood/height),
  282. (1.0 - tlFood/diag),
  283. (1.0 - trFood/diag),
  284. (1.0 - blFood/diag),
  285. (1.0 - brFood/diag),
  286. tTail / height,
  287. bTail / height,
  288. lTail / width,
  289. rTail / width,
  290. tlTail / diag,
  291. trTail / diag,
  292. blTail / diag,
  293. brTail / diag,
  294. }
  295. }
  296. // Server part
  297. // Runs gRPC server for GUI
  298. func (s *SnakeSimulator) StartServer() {
  299. go func() {
  300. grpcServer := grpc.NewServer()
  301. RegisterSnakeSimulatorServer(grpcServer, s)
  302. lis, err := net.Listen("tcp", "localhost:65002")
  303. if err != nil {
  304. fmt.Printf("Failed to listen: %v\n", err)
  305. }
  306. fmt.Printf("Listen SnakeSimulator localhost:65002\n")
  307. if err := grpcServer.Serve(lis); err != nil {
  308. fmt.Printf("Failed to serve: %v\n", err)
  309. }
  310. }()
  311. }
  312. // Steaming of Field updates
  313. func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
  314. ctx := srv.Context()
  315. for {
  316. select {
  317. case <-ctx.Done():
  318. return ctx.Err()
  319. default:
  320. }
  321. srv.Send(s.field)
  322. <-s.fieldUpdateQueue
  323. }
  324. }
  325. // Steaming of Snake position and length updates
  326. func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
  327. ctx := srv.Context()
  328. for {
  329. select {
  330. case <-ctx.Done():
  331. return ctx.Err()
  332. default:
  333. }
  334. srv.Send(s.snake)
  335. <-s.snakeUpdateQueue
  336. }
  337. }
  338. // Steaming of snake simulator statistic
  339. func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
  340. ctx := srv.Context()
  341. for {
  342. select {
  343. case <-ctx.Done():
  344. return ctx.Err()
  345. default:
  346. }
  347. srv.Send(s.stats)
  348. <-s.statsUpdateQueue
  349. }
  350. }
  351. // Setup new speed requested from gRPC GUI client
  352. func (s *SnakeSimulator) SetSpeed(ctx context.Context, speed *Speed) (*None, error) {
  353. s.speedQueue <- speed.Speed
  354. return &None{}, nil
  355. }