snakesimulator.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. package snakesimulator
  2. import (
  3. context "context"
  4. fmt "fmt"
  5. math "math"
  6. "math/rand"
  7. "net"
  8. "sort"
  9. "sync"
  10. "time"
  11. "gonum.org/v1/gonum/mat"
  12. genetic "../genetic"
  13. neuralnetwork "../neuralnetwork"
  14. remotecontrol "../remotecontrol"
  15. grpc "google.golang.org/grpc"
  16. )
  17. type SnakeSimulator struct {
  18. field *Field
  19. snake *Snake
  20. maxVerificationSteps int
  21. stats *Stats
  22. remoteControl *remotecontrol.RemoteControl
  23. //GUI interface part
  24. speed uint32
  25. fieldUpdateQueue chan bool
  26. snakeUpdateQueue chan bool
  27. statsUpdateQueue chan bool
  28. speedQueue chan uint32
  29. snakeReadMutex sync.Mutex
  30. fieldReadMutex sync.Mutex
  31. }
  32. // Initializes new snake population with maximum number of verification steps
  33. func NewSnakeSimulator(maxVerificationSteps int) (s *SnakeSimulator) {
  34. s = &SnakeSimulator{
  35. field: &Field{
  36. Food: &Point{},
  37. Width: 40,
  38. Height: 40,
  39. },
  40. snake: &Snake{
  41. Points: []*Point{
  42. &Point{X: 20, Y: 20},
  43. &Point{X: 21, Y: 20},
  44. &Point{X: 22, Y: 20},
  45. },
  46. },
  47. stats: &Stats{},
  48. maxVerificationSteps: maxVerificationSteps,
  49. fieldUpdateQueue: make(chan bool, 2),
  50. snakeUpdateQueue: make(chan bool, 2),
  51. statsUpdateQueue: make(chan bool, 2),
  52. speedQueue: make(chan uint32, 1),
  53. speed: 10,
  54. remoteControl: remotecontrol.NewRemoteControl(),
  55. }
  56. return
  57. }
  58. // Population test method
  59. // Verifies population and returns unsorted finteses for each individual
  60. func (s *SnakeSimulator) Verify(population *genetic.Population) (fitnesses []*genetic.IndividalFitness) {
  61. s.remoteControl.Init(population.Networks[0])
  62. s.stats.Generation++
  63. s.statsUpdateQueue <- true
  64. s.field.GenerateNextFood()
  65. if s.speed > 0 {
  66. s.fieldUpdateQueue <- true
  67. }
  68. fitnesses = make([]*genetic.IndividalFitness, len(population.Networks))
  69. for index, inidividual := range population.Networks {
  70. s.stats.Individual = uint32(index)
  71. s.statsUpdateQueue <- true
  72. s.runSnake(inidividual, false)
  73. fitnesses[index] = &genetic.IndividalFitness{
  74. // Fitness: float64(s.stats.Move), //Uncomment this to decrese food impact to individual selection
  75. Fitness: float64(s.stats.Move) * float64(len(s.snake.Points)-2),
  76. Index: index,
  77. }
  78. }
  79. //This is duplication of crossbreedPopulation functionality to display best snake
  80. sort.Slice(fitnesses, func(i, j int) bool {
  81. return fitnesses[i].Fitness > fitnesses[j].Fitness //Descent order best will be on top, worst in the bottom
  82. })
  83. //Best snake showtime!
  84. s.fieldReadMutex.Lock()
  85. s.field.GenerateNextFood()
  86. s.fieldReadMutex.Unlock()
  87. s.fieldUpdateQueue <- true
  88. prevSpeed := s.speed
  89. s.speed = 5
  90. population.Networks[fitnesses[0].Index].SetStateWatcher(s.remoteControl)
  91. s.runSnake(population.Networks[fitnesses[0].Index], false)
  92. population.Networks[fitnesses[0].Index].SetStateWatcher(nil)
  93. s.speed = prevSpeed
  94. return
  95. }
  96. func (s *SnakeSimulator) runSnake(inidividual *neuralnetwork.NeuralNetwork, randomStart bool) {
  97. s.snakeReadMutex.Lock()
  98. if randomStart {
  99. rand.Seed(time.Now().UnixNano())
  100. s.snake = NewSnake(Direction(rand.Uint32()%4), *s.field)
  101. } else {
  102. s.snake = NewSnake(Direction_Left, *s.field)
  103. }
  104. s.snakeReadMutex.Unlock()
  105. s.stats.Move = 0
  106. for i := 0; i < s.maxVerificationSteps; i++ {
  107. //Read speed from client and sleep in case if user selected slow preview
  108. select {
  109. case newSpeed := <-s.speedQueue:
  110. fmt.Printf("Apply new speed: %v\n", newSpeed)
  111. if newSpeed <= 10 && newSpeed >= 0 {
  112. s.speed = newSpeed
  113. } else if newSpeed < 0 {
  114. s.speed = 0
  115. }
  116. default:
  117. }
  118. if s.speed > 0 {
  119. time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
  120. s.statsUpdateQueue <- true
  121. s.snakeUpdateQueue <- true
  122. }
  123. predictIndex, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.getHeadState()))
  124. direction := Direction(predictIndex + 1)
  125. newHead := s.snake.NewHead(direction)
  126. if s.snake.selfCollision(newHead, direction) {
  127. fmt.Printf("Game over self collision\n")
  128. break
  129. } else if wallCollision(newHead, *s.field) {
  130. break
  131. } else if foodCollision(newHead, s.field.Food) {
  132. i = 0
  133. s.snakeReadMutex.Lock()
  134. s.snake.Feed(newHead)
  135. s.snakeReadMutex.Unlock()
  136. s.fieldReadMutex.Lock()
  137. s.field.GenerateNextFood()
  138. s.fieldReadMutex.Unlock()
  139. if s.speed > 0 {
  140. s.fieldUpdateQueue <- true
  141. }
  142. } else {
  143. s.snakeReadMutex.Lock()
  144. s.snake.Move(newHead)
  145. s.snakeReadMutex.Unlock()
  146. }
  147. s.stats.Move++
  148. }
  149. }
  150. // Produces input activations for neural network
  151. func (s *SnakeSimulator) getHeadState() []float64 {
  152. // Snake state
  153. headX := float64(s.snake.Points[0].X)
  154. headY := float64(s.snake.Points[0].Y)
  155. tailX := float64(s.snake.Points[len(s.snake.Points)-1].X)
  156. tailY := float64(s.snake.Points[len(s.snake.Points)-1].Y)
  157. // Field state
  158. foodX := float64(s.field.Food.X)
  159. foodY := float64(s.field.Food.Y)
  160. width := float64(s.field.Width)
  161. height := float64(s.field.Height)
  162. diag := float64(width) * math.Sqrt2 //We assume that field is always square
  163. // Output activations
  164. // Distance to walls in 4 directions
  165. lWall := headX
  166. rWall := (width - headX)
  167. tWall := headY
  168. bWall := (height - headY)
  169. // Distance to walls in 4 diagonal directions, by default is completely inactive
  170. tlWall := float64(0)
  171. trWall := float64(0)
  172. blWall := float64(0)
  173. brWall := float64(0)
  174. // Distance to food in 4 directions
  175. // By default is size of field that means that there is no activation at all
  176. lFood := float64(width)
  177. rFood := float64(width)
  178. tFood := float64(height)
  179. bFood := float64(height)
  180. // Distance to food in 4 diagonal directions
  181. // By default is size of field diagonal that means that there is no activation
  182. // at all
  183. tlFood := float64(diag)
  184. trFood := float64(diag)
  185. blFood := float64(diag)
  186. brFood := float64(diag)
  187. // Distance to tail in 4 directions
  188. tTail := float64(0)
  189. bTail := float64(0)
  190. lTail := float64(0)
  191. rTail := float64(0)
  192. // Distance to tail in 4 diagonal directions
  193. tlTail := float64(0)
  194. trTail := float64(0)
  195. blTail := float64(0)
  196. brTail := float64(0)
  197. // Diagonal distance to each wall
  198. if lWall > tWall {
  199. tlWall = float64(tWall) * math.Sqrt2
  200. } else {
  201. tlWall = float64(lWall) * math.Sqrt2
  202. }
  203. if rWall > tWall {
  204. trWall = float64(tWall) * math.Sqrt2
  205. } else {
  206. trWall = float64(rWall) * math.Sqrt2
  207. }
  208. if lWall > bWall {
  209. blWall = float64(bWall) * math.Sqrt2
  210. } else {
  211. blWall = float64(lWall) * math.Sqrt2
  212. }
  213. if rWall > bWall {
  214. blWall = float64(bWall) * math.Sqrt2
  215. } else {
  216. brWall = float64(rWall) * math.Sqrt2
  217. }
  218. // Check if food is on same vertical line with head and
  219. // choose vertical direction for activation
  220. if headX == foodX {
  221. if headY-foodY > 0 {
  222. tFood = 0
  223. } else {
  224. bFood = 0
  225. }
  226. }
  227. // Check if food is on same horizontal line with head and
  228. // choose horizontal direction for activation
  229. if headY == foodY {
  230. if headX-foodX > 0 {
  231. lFood = 0
  232. } else {
  233. rFood = 0
  234. }
  235. }
  236. //Check if food is on diagonal any of 4 ways
  237. if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
  238. //Choose diagonal direction to food
  239. if foodX > headX {
  240. if foodY > headY {
  241. trFood = 0
  242. } else {
  243. brFood = 0
  244. }
  245. } else {
  246. if foodY > headY {
  247. tlFood = 0
  248. } else {
  249. blFood = 0
  250. }
  251. }
  252. }
  253. // Check if tail is on same vertical line with head and
  254. // choose vertical direction for activation
  255. if headX == tailX {
  256. if headY-tailY > 0 {
  257. tTail = headY - tailY
  258. } else {
  259. bTail = headY - tailY
  260. }
  261. }
  262. // Check if tail is on same horizontal line with head and
  263. // choose horizontal direction for activation
  264. if headY == tailY {
  265. if headX-tailX > 0 {
  266. rTail = headX - tailX
  267. } else {
  268. lTail = headX - tailX
  269. }
  270. }
  271. //Check if tail is on diagonal any of 4 ways
  272. if math.Abs(headY-tailY) == math.Abs(headX-tailX) {
  273. //Choose diagonal direction to tail
  274. if tailY > headY {
  275. if tailX > headX {
  276. trTail = diag
  277. } else {
  278. tlTail = diag
  279. }
  280. } else {
  281. if tailX > headX {
  282. brTail = diag
  283. } else {
  284. blTail = diag
  285. }
  286. }
  287. }
  288. return []float64{
  289. lWall / width,
  290. rWall / width,
  291. tWall / height,
  292. bWall / height,
  293. tlWall / diag,
  294. trWall / diag,
  295. blWall / diag,
  296. brWall / diag,
  297. (1.0 - lFood/width),
  298. (1.0 - rFood/width),
  299. (1.0 - tFood/height),
  300. (1.0 - bFood/height),
  301. (1.0 - tlFood/diag),
  302. (1.0 - trFood/diag),
  303. (1.0 - blFood/diag),
  304. (1.0 - brFood/diag),
  305. tTail / height,
  306. bTail / height,
  307. lTail / width,
  308. rTail / width,
  309. tlTail / diag,
  310. trTail / diag,
  311. blTail / diag,
  312. brTail / diag,
  313. }
  314. }
  315. // Server part
  316. // Runs gRPC server for GUI
  317. func (s *SnakeSimulator) StartServer() {
  318. go func() {
  319. grpcServer := grpc.NewServer()
  320. RegisterSnakeSimulatorServer(grpcServer, s)
  321. lis, err := net.Listen("tcp", "localhost:65002")
  322. if err != nil {
  323. fmt.Printf("Failed to listen: %v\n", err)
  324. }
  325. fmt.Printf("Listen SnakeSimulator localhost:65002\n")
  326. if err := grpcServer.Serve(lis); err != nil {
  327. fmt.Printf("Failed to serve: %v\n", err)
  328. }
  329. }()
  330. go s.remoteControl.Run()
  331. }
  332. // Steaming of Field updates
  333. func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
  334. ctx := srv.Context()
  335. for {
  336. select {
  337. case <-ctx.Done():
  338. return ctx.Err()
  339. default:
  340. }
  341. s.snakeReadMutex.Lock()
  342. srv.Send(s.field)
  343. s.snakeReadMutex.Unlock()
  344. <-s.fieldUpdateQueue
  345. }
  346. }
  347. // Steaming of Snake position and length updates
  348. func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
  349. ctx := srv.Context()
  350. for {
  351. select {
  352. case <-ctx.Done():
  353. return ctx.Err()
  354. default:
  355. }
  356. srv.Send(s.snake)
  357. <-s.snakeUpdateQueue
  358. }
  359. }
  360. // Steaming of snake simulator statistic
  361. func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
  362. ctx := srv.Context()
  363. for {
  364. select {
  365. case <-ctx.Done():
  366. return ctx.Err()
  367. default:
  368. }
  369. s.fieldReadMutex.Lock()
  370. srv.Send(s.stats)
  371. s.fieldReadMutex.Unlock()
  372. <-s.statsUpdateQueue
  373. }
  374. }
  375. // Setup new speed requested from gRPC GUI client
  376. func (s *SnakeSimulator) SetSpeed(ctx context.Context, speed *Speed) (*None, error) {
  377. s.speedQueue <- speed.Speed
  378. return &None{}, nil
  379. }