snakesimulator.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. package snakesimulator
  2. import (
  3. context "context"
  4. fmt "fmt"
  5. math "math"
  6. "math/rand"
  7. "net"
  8. "sort"
  9. "sync"
  10. "time"
  11. "gonum.org/v1/gonum/mat"
  12. genetic "../genetic"
  13. neuralnetwork "../neuralnetwork"
  14. grpc "google.golang.org/grpc"
  15. )
  16. type SnakeSimulator struct {
  17. field *Field
  18. snake *Snake
  19. maxVerificationSteps int
  20. stats *Stats
  21. //GUI interface part
  22. speed uint32
  23. fieldUpdateQueue chan bool
  24. snakeUpdateQueue chan bool
  25. statsUpdateQueue chan bool
  26. speedQueue chan uint32
  27. snakeReadMutex sync.Mutex
  28. fieldReadMutex sync.Mutex
  29. }
  30. // Initializes new snake population with maximum number of verification steps
  31. func NewSnakeSimulator(maxVerificationSteps int) (s *SnakeSimulator) {
  32. s = &SnakeSimulator{
  33. field: &Field{
  34. Food: &Point{},
  35. Width: 40,
  36. Height: 40,
  37. },
  38. snake: &Snake{
  39. Points: []*Point{
  40. &Point{X: 20, Y: 20},
  41. &Point{X: 21, Y: 20},
  42. &Point{X: 22, Y: 20},
  43. },
  44. },
  45. stats: &Stats{},
  46. maxVerificationSteps: maxVerificationSteps,
  47. fieldUpdateQueue: make(chan bool, 2),
  48. snakeUpdateQueue: make(chan bool, 2),
  49. statsUpdateQueue: make(chan bool, 2),
  50. speedQueue: make(chan uint32, 1),
  51. speed: 10,
  52. }
  53. return
  54. }
  55. // Population test method
  56. // Verifies population and returns unsorted finteses for each individual
  57. func (s *SnakeSimulator) Verify(population *genetic.Population) (fitnesses []*genetic.IndividalFitness) {
  58. s.stats.Generation++
  59. s.statsUpdateQueue <- true
  60. s.field.GenerateNextFood()
  61. if s.speed > 0 {
  62. s.fieldUpdateQueue <- true
  63. }
  64. fitnesses = make([]*genetic.IndividalFitness, len(population.Networks))
  65. for index, inidividual := range population.Networks {
  66. s.stats.Individual = uint32(index)
  67. s.statsUpdateQueue <- true
  68. s.runSnake(inidividual, false)
  69. fitnesses[index] = &genetic.IndividalFitness{
  70. // Fitness: float64(s.stats.Move), //Uncomment this to decrese food impact to individual selection
  71. Fitness: float64(s.stats.Move) * float64(len(s.snake.Points)-2),
  72. Index: index,
  73. }
  74. }
  75. //This is duplication of crossbreedPopulation functionality to display best snake
  76. sort.Slice(fitnesses, func(i, j int) bool {
  77. return fitnesses[i].Fitness > fitnesses[j].Fitness //Descent order best will be on top, worst in the bottom
  78. })
  79. //Best snake showtime!
  80. s.fieldReadMutex.Lock()
  81. s.field.GenerateNextFood()
  82. s.fieldReadMutex.Unlock()
  83. s.fieldUpdateQueue <- true
  84. prevSpeed := s.speed
  85. s.speed = 5
  86. s.runSnake(population.Networks[fitnesses[0].Index], false)
  87. s.speed = prevSpeed
  88. return
  89. }
  90. func (s *SnakeSimulator) runSnake(inidividual *neuralnetwork.NeuralNetwork, randomStart bool) {
  91. s.snakeReadMutex.Lock()
  92. if randomStart {
  93. rand.Seed(time.Now().UnixNano())
  94. s.snake = NewSnake(Direction(rand.Uint32()%4), *s.field)
  95. } else {
  96. s.snake = NewSnake(Direction_Left, *s.field)
  97. }
  98. s.snakeReadMutex.Unlock()
  99. s.stats.Move = 0
  100. for i := 0; i < s.maxVerificationSteps; i++ {
  101. //Read speed from client and sleep in case if user selected slow preview
  102. select {
  103. case newSpeed := <-s.speedQueue:
  104. fmt.Printf("Apply new speed: %v\n", newSpeed)
  105. if newSpeed <= 10 && newSpeed >= 0 {
  106. s.speed = newSpeed
  107. } else if newSpeed < 0 {
  108. s.speed = 0
  109. }
  110. default:
  111. }
  112. if s.speed > 0 {
  113. time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
  114. s.statsUpdateQueue <- true
  115. s.snakeUpdateQueue <- true
  116. }
  117. predictIndex, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.getHeadState()))
  118. direction := Direction(predictIndex + 1)
  119. newHead := s.snake.NewHead(direction)
  120. if s.snake.selfCollision(newHead, direction) {
  121. fmt.Printf("Game over self collision\n")
  122. break
  123. } else if wallCollision(newHead, *s.field) {
  124. break
  125. } else if foodCollision(newHead, s.field.Food) {
  126. i = 0
  127. s.snakeReadMutex.Lock()
  128. s.snake.Feed(newHead)
  129. s.snakeReadMutex.Unlock()
  130. s.fieldReadMutex.Lock()
  131. s.field.GenerateNextFood()
  132. s.fieldReadMutex.Unlock()
  133. if s.speed > 0 {
  134. s.fieldUpdateQueue <- true
  135. }
  136. } else {
  137. s.snakeReadMutex.Lock()
  138. s.snake.Move(newHead)
  139. s.snakeReadMutex.Unlock()
  140. }
  141. s.stats.Move++
  142. }
  143. }
  144. // Produces input activations for neural network
  145. func (s *SnakeSimulator) getHeadState() []float64 {
  146. // Snake state
  147. headX := float64(s.snake.Points[0].X)
  148. headY := float64(s.snake.Points[0].Y)
  149. tailX := float64(s.snake.Points[len(s.snake.Points)-1].X)
  150. tailY := float64(s.snake.Points[len(s.snake.Points)-1].Y)
  151. // Field state
  152. foodX := float64(s.field.Food.X)
  153. foodY := float64(s.field.Food.Y)
  154. width := float64(s.field.Width)
  155. height := float64(s.field.Height)
  156. diag := float64(width) * math.Sqrt2 //We assume that field is always square
  157. // Output activations
  158. // Distance to walls in 4 directions
  159. lWall := headX
  160. rWall := (width - headX)
  161. tWall := headY
  162. bWall := (height - headY)
  163. // Distance to walls in 4 diagonal directions, by default is completely inactive
  164. tlWall := float64(0)
  165. trWall := float64(0)
  166. blWall := float64(0)
  167. brWall := float64(0)
  168. // Distance to food in 4 directions
  169. // By default is size of field that means that there is no activation at all
  170. lFood := float64(width)
  171. rFood := float64(width)
  172. tFood := float64(height)
  173. bFood := float64(height)
  174. // Distance to food in 4 diagonal directions
  175. // By default is size of field diagonal that means that there is no activation
  176. // at all
  177. tlFood := float64(diag)
  178. trFood := float64(diag)
  179. blFood := float64(diag)
  180. brFood := float64(diag)
  181. // Distance to tail in 4 directions
  182. tTail := float64(0)
  183. bTail := float64(0)
  184. lTail := float64(0)
  185. rTail := float64(0)
  186. // Distance to tail in 4 diagonal directions
  187. tlTail := float64(0)
  188. trTail := float64(0)
  189. blTail := float64(0)
  190. brTail := float64(0)
  191. // Diagonal distance to each wall
  192. if lWall > tWall {
  193. tlWall = float64(tWall) * math.Sqrt2
  194. } else {
  195. tlWall = float64(lWall) * math.Sqrt2
  196. }
  197. if rWall > tWall {
  198. trWall = float64(tWall) * math.Sqrt2
  199. } else {
  200. trWall = float64(rWall) * math.Sqrt2
  201. }
  202. if lWall > bWall {
  203. blWall = float64(bWall) * math.Sqrt2
  204. } else {
  205. blWall = float64(lWall) * math.Sqrt2
  206. }
  207. if rWall > bWall {
  208. blWall = float64(bWall) * math.Sqrt2
  209. } else {
  210. brWall = float64(rWall) * math.Sqrt2
  211. }
  212. // Check if food is on same vertical line with head and
  213. // choose vertical direction for activation
  214. if headX == foodX {
  215. if headY-foodY > 0 {
  216. tFood = 0
  217. } else {
  218. bFood = 0
  219. }
  220. }
  221. // Check if food is on same horizontal line with head and
  222. // choose horizontal direction for activation
  223. if headY == foodY {
  224. if headX-foodX > 0 {
  225. lFood = 0
  226. } else {
  227. rFood = 0
  228. }
  229. }
  230. //Check if food is on diagonal any of 4 ways
  231. if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
  232. //Choose diagonal direction to food
  233. if foodX > headX {
  234. if foodY > headY {
  235. trFood = 0
  236. } else {
  237. brFood = 0
  238. }
  239. } else {
  240. if foodY > headY {
  241. tlFood = 0
  242. } else {
  243. blFood = 0
  244. }
  245. }
  246. }
  247. // Check if tail is on same vertical line with head and
  248. // choose vertical direction for activation
  249. if headX == tailX {
  250. if headY-tailY > 0 {
  251. tTail = headY - tailY
  252. } else {
  253. bTail = headY - tailY
  254. }
  255. }
  256. // Check if tail is on same horizontal line with head and
  257. // choose horizontal direction for activation
  258. if headY == tailY {
  259. if headX-tailX > 0 {
  260. rTail = headX - tailX
  261. } else {
  262. lTail = headX - tailX
  263. }
  264. }
  265. //Check if tail is on diagonal any of 4 ways
  266. if math.Abs(headY-tailY) == math.Abs(headX-tailX) {
  267. //Choose diagonal direction to tail
  268. if tailY > headY {
  269. if tailX > headX {
  270. trTail = diag
  271. } else {
  272. tlTail = diag
  273. }
  274. } else {
  275. if tailX > headX {
  276. brTail = diag
  277. } else {
  278. blTail = diag
  279. }
  280. }
  281. }
  282. return []float64{
  283. lWall / width,
  284. rWall / width,
  285. tWall / height,
  286. bWall / height,
  287. tlWall / diag,
  288. trWall / diag,
  289. blWall / diag,
  290. brWall / diag,
  291. (1.0 - lFood/width),
  292. (1.0 - rFood/width),
  293. (1.0 - tFood/height),
  294. (1.0 - bFood/height),
  295. (1.0 - tlFood/diag),
  296. (1.0 - trFood/diag),
  297. (1.0 - blFood/diag),
  298. (1.0 - brFood/diag),
  299. tTail / height,
  300. bTail / height,
  301. lTail / width,
  302. rTail / width,
  303. tlTail / diag,
  304. trTail / diag,
  305. blTail / diag,
  306. brTail / diag,
  307. }
  308. }
  309. // Server part
  310. // Runs gRPC server for GUI
  311. func (s *SnakeSimulator) StartServer() {
  312. go func() {
  313. grpcServer := grpc.NewServer()
  314. RegisterSnakeSimulatorServer(grpcServer, s)
  315. lis, err := net.Listen("tcp", "localhost:65002")
  316. if err != nil {
  317. fmt.Printf("Failed to listen: %v\n", err)
  318. }
  319. fmt.Printf("Listen SnakeSimulator localhost:65002\n")
  320. if err := grpcServer.Serve(lis); err != nil {
  321. fmt.Printf("Failed to serve: %v\n", err)
  322. }
  323. }()
  324. }
  325. // Steaming of Field updates
  326. func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
  327. ctx := srv.Context()
  328. for {
  329. select {
  330. case <-ctx.Done():
  331. return ctx.Err()
  332. default:
  333. }
  334. s.snakeReadMutex.Lock()
  335. srv.Send(s.field)
  336. s.snakeReadMutex.Unlock()
  337. <-s.fieldUpdateQueue
  338. }
  339. }
  340. // Steaming of Snake position and length updates
  341. func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
  342. ctx := srv.Context()
  343. for {
  344. select {
  345. case <-ctx.Done():
  346. return ctx.Err()
  347. default:
  348. }
  349. srv.Send(s.snake)
  350. <-s.snakeUpdateQueue
  351. }
  352. }
  353. // Steaming of snake simulator statistic
  354. func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
  355. ctx := srv.Context()
  356. for {
  357. select {
  358. case <-ctx.Done():
  359. return ctx.Err()
  360. default:
  361. }
  362. s.fieldReadMutex.Lock()
  363. srv.Send(s.stats)
  364. s.fieldReadMutex.Unlock()
  365. <-s.statsUpdateQueue
  366. }
  367. }
  368. // Setup new speed requested from gRPC GUI client
  369. func (s *SnakeSimulator) SetSpeed(ctx context.Context, speed *Speed) (*None, error) {
  370. s.speedQueue <- speed.Speed
  371. return &None{}, nil
  372. }