Quellcode durchsuchen

Implement basic snake trainng

- Not sure if calculated resuls are correct
Alexey Edelev vor 5 Jahren
Ursprung
Commit
c227a47459

+ 3 - 3
neuralnetwork/genetic/genetic.go

@@ -46,12 +46,12 @@ func (p *Population) NaturalSelection(generationCount int) {
 
 func (p *Population) crossbreedPopulation(results []*IndividalResult) {
 	sort.Slice(results, func(i, j int) bool {
-		return results[i].result < results[j].result
+		return results[i].Result < results[j].Result
 	})
 
 	for i := 1; i < p.populationSize; i += 2 {
-		firstParent := results[i].index
-		secondParent := results[i-1].index
+		firstParent := results[i].Index
+		secondParent := results[i-1].Index
 		crossbreed(p.Networks[firstParent], p.Networks[secondParent])
 		p.mutagen.Mutate(p.Networks[firstParent])
 		p.mutagen.Mutate(p.Networks[secondParent])

+ 2 - 2
neuralnetwork/genetic/interface.go

@@ -3,8 +3,8 @@ package genetic
 import neuralnetwork "../neuralnetwork"
 
 type IndividalResult struct {
-	result float64
-	index  int
+	Result float64
+	Index  int
 }
 
 type PopulationVerifier interface {

+ 20 - 15
neuralnetwork/main.go

@@ -1,25 +1,30 @@
 package main
 
 import (
-	neuralnetwork "./neuralnetwork"
-	remotecontrol "./remotecontrol"
+	genetic "./genetic"
+	mutagen "./genetic/mutagen"
 	snakesimulator "./snakesimulator"
 )
 
 func main() {
 	s := snakesimulator.NewSnakeSimulator()
-	s.Run()
-	// genetic.NewPopulation(nil, mutagen.NewDummyMutagen(50), 200, []int{13, 8, 12, 3})
-	sizes := []int{13, 8, 12, 3}
-	nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
-		NuPlus:   1.2,
-		NuMinus:  0.5,
-		DeltaMax: 50.0,
-		DeltaMin: 0.000001,
-	}))
-
-	rc := &remotecontrol.RemoteControl{}
-	nn.SetStateWatcher(rc)
+	s.StartServer()
+	p := genetic.NewPopulation(s, mutagen.NewDummyMutagen(50), 40, []int{16, 12, 12, 4})
+	p.NaturalSelection(200)
+	// s.Run()
+
+	// sizes := []int{13, 8, 12, 3}
+	// nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
+	// 	NuPlus:   1.2,
+	// 	NuMinus:  0.5,
+	// 	DeltaMax: 50.0,
+	// 	DeltaMin: 0.000001,
+	// }))
+
+	// rc := &remotecontrol.RemoteControl{}
+	// nn.SetStateWatcher(rc)
+	// rc.Run()
+
 	// inFile, err := os.Open("./networkstate")
 	// if err != nil {
 	// 	log.Fatal(err)
@@ -59,5 +64,5 @@ func main() {
 	// }
 
 	// fmt.Printf("Fail count: %v\n\n", failCount)
-	rc.Run()
+
 }

+ 4 - 2
neuralnetwork/neuralnetwork/neuralnetwork.go

@@ -136,8 +136,10 @@ func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentIni
 	for l := 1; l < nn.LayerCount; l++ {
 		nn.Biases[l] = generateRandomDense(nn.Sizes[l], 1)
 		nn.Weights[l] = generateRandomDense(nn.Sizes[l], nn.Sizes[l-1])
-		nn.BGradient[l] = nn.gradientDescentInitializer(nn, l, BiasGradient)
-		nn.WGradient[l] = nn.gradientDescentInitializer(nn, l, WeightGradient)
+		if nn.gradientDescentInitializer != nil {
+			nn.BGradient[l] = nn.gradientDescentInitializer(nn, l, BiasGradient)
+			nn.WGradient[l] = nn.gradientDescentInitializer(nn, l, WeightGradient)
+		}
 	}
 	return
 }

+ 185 - 1
neuralnetwork/snakesimulator/snakesimulator.go

@@ -2,18 +2,24 @@ package snakesimulator
 
 import (
 	fmt "fmt"
+	math "math"
 	"math/rand"
 	"net"
 	"time"
 
+	"gonum.org/v1/gonum/mat"
+
+	genetic "../genetic"
 	grpc "google.golang.org/grpc"
 )
 
 type SnakeSimulator struct {
 	field            *Field
 	snake            *Snake
+	stats            *Stats
 	fieldUpdateQueue chan bool
 	snakeUpdateQueue chan bool
+	statsUpdateQueue chan bool
 }
 
 func NewSnakeSimulator() (s *SnakeSimulator) {
@@ -30,13 +36,176 @@ func NewSnakeSimulator() (s *SnakeSimulator) {
 				&Point{X: 22, Y: 20},
 			},
 		},
+		stats:            &Stats{},
 		fieldUpdateQueue: make(chan bool, 2),
 		snakeUpdateQueue: make(chan bool, 2),
+		statsUpdateQueue: make(chan bool, 2),
 	}
 	return
 }
 
-func (s *SnakeSimulator) Run() {
+func (s *SnakeSimulator) Verify(population *genetic.Population) (results []*genetic.IndividalResult) {
+	s.stats.Generation++
+	s.statsUpdateQueue <- true
+
+	results = make([]*genetic.IndividalResult, len(population.Networks))
+	for index, inidividual := range population.Networks {
+		s.stats.Individual = uint32(index)
+		s.statsUpdateQueue <- true
+
+		s.field.GenerateNextFood()
+		s.snake = &Snake{
+			Points: []*Point{
+				&Point{X: 20, Y: 20},
+				&Point{X: 21, Y: 20},
+				&Point{X: 22, Y: 20},
+			},
+		}
+
+		i := 0
+		for i < 300 {
+			s.stats.Move = uint32(i)
+			s.statsUpdateQueue <- true
+
+			i++
+			s.snakeUpdateQueue <- true
+			direction, _ := inidividual.Predict(mat.NewDense(inidividual.Sizes[0], 1, s.GetHeadState()))
+			newHead := s.snake.NewHead(Direction(direction + 1))
+			if newHead.X == s.field.Food.X && newHead.Y == s.field.Food.Y {
+				s.snake.Feed(newHead)
+				s.field.GenerateNextFood()
+				s.fieldUpdateQueue <- true
+			} else if newHead.X >= s.field.Width || newHead.Y >= s.field.Height {
+				fmt.Printf("Game over\n")
+				// time.Sleep(1000 * time.Millisecond)
+				break
+			} else if selfCollisionIndex := s.snake.SelfCollision(newHead); selfCollisionIndex > 0 {
+				if selfCollisionIndex == 1 {
+					fmt.Printf("Step backward, skip\n")
+					continue
+				}
+				fmt.Printf("Game over self collision\n")
+				break
+			} else {
+				s.snake.Move(newHead)
+			}
+			time.Sleep(10 * time.Millisecond)
+		}
+
+		results[index] = &genetic.IndividalResult{
+			Result: float64(len(s.snake.Points)) * float64(i),
+			Index:  index,
+		}
+	}
+	return
+}
+
+func (s *SnakeSimulator) GetHeadState() []float64 {
+	headX := int32(s.snake.Points[0].X)
+	headY := int32(s.snake.Points[0].Y)
+	foodX := int32(s.field.Food.X)
+	foodY := int32(s.field.Food.Y)
+	width := int32(s.field.Width)
+	height := int32(s.field.Height)
+	diag := float64(width) * math.Sqrt2
+
+	lWall := headX
+	rWall := width - headX
+	tWall := headY
+	bWall := height - headY
+	lFood := int32(0)
+	rFood := int32(0)
+	tFood := int32(0)
+	bFood := int32(0)
+
+	tlFood := float64(0)
+	trFood := float64(0)
+	blFood := float64(0)
+	brFood := float64(0)
+	tlWall := float64(0)
+	trWall := float64(0)
+	blWall := float64(0)
+	brWall := float64(0)
+
+	if foodX == headX {
+		if foodY > headY {
+			bFood = foodY - headY
+		} else {
+			tFood = headY - foodY
+		}
+	}
+
+	if foodY == headY {
+		if foodX > headX {
+			rFood = foodX - headX
+		} else {
+			lFood = headX - foodX
+		}
+	}
+
+	if lWall > tWall {
+		tlWall = float64(tWall) * math.Sqrt2
+	} else {
+		tlWall = float64(lWall) * math.Sqrt2
+	}
+
+	if rWall > tWall {
+		trWall = float64(tWall) * math.Sqrt2
+	} else {
+		trWall = float64(rWall) * math.Sqrt2
+	}
+
+	if lWall > bWall {
+		blWall = float64(bWall) * math.Sqrt2
+	} else {
+		blWall = float64(lWall) * math.Sqrt2
+	}
+
+	if rWall > bWall {
+		blWall = float64(bWall) * math.Sqrt2
+	} else {
+		brWall = float64(rWall) * math.Sqrt2
+	}
+
+	foodDiagXDiff := math.Abs(float64(foodX - headX))
+	foodDiagYDiff := math.Abs(float64(foodY - headY))
+	if foodDiagXDiff == foodDiagYDiff {
+		if math.Signbit(float64(foodX - headX)) {
+			if math.Signbit(float64(foodY - headY)) {
+				trFood = foodDiagXDiff * math.Sqrt2
+			} else {
+				brFood = foodDiagXDiff * math.Sqrt2
+			}
+		} else {
+			if math.Signbit(float64(foodY - headY)) {
+				tlFood = foodDiagXDiff * math.Sqrt2
+			} else {
+				blFood = foodDiagXDiff * math.Sqrt2
+			}
+		}
+	}
+
+	return []float64{
+		float64(lWall) / float64(width),
+		float64(rWall) / float64(width),
+		float64(tWall) / float64(height),
+		float64(bWall) / float64(height),
+		float64(lFood) / float64(width),
+		float64(rFood) / float64(width),
+		float64(tFood) / float64(height),
+		float64(bFood) / float64(height),
+		float64(tlFood) / diag,
+		float64(trFood) / diag,
+		float64(blFood) / diag,
+		float64(brFood) / diag,
+		float64(tlWall) / diag,
+		float64(trWall) / diag,
+		float64(blWall) / diag,
+		float64(brWall) / diag,
+	}
+}
+
+func (s *SnakeSimulator) StartServer() {
 	go func() {
 		grpcServer := grpc.NewServer()
 		RegisterSnakeSimulatorServer(grpcServer, s)
@@ -50,7 +219,9 @@ func (s *SnakeSimulator) Run() {
 			fmt.Printf("Failed to serve: %v\n", err)
 		}
 	}()
+}
 
+func (s *SnakeSimulator) Run() {
 	s.field.GenerateNextFood()
 	for true {
 		direction := rand.Int31()%4 + 1
@@ -102,3 +273,16 @@ func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
 		<-s.snakeUpdateQueue
 	}
 }
+
+func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
+	ctx := srv.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		srv.Send(s.stats)
+		<-s.statsUpdateQueue
+	}
+}

+ 142 - 20
neuralnetwork/snakesimulator/snakesimulator.pb.go

@@ -199,6 +199,61 @@ func (m *Field) GetFood() *Point {
 	return nil
 }
 
+type Stats struct {
+	Generation           uint32   `protobuf:"varint,1,opt,name=generation,proto3" json:"generation,omitempty"`
+	Individual           uint32   `protobuf:"varint,2,opt,name=individual,proto3" json:"individual,omitempty"`
+	Move                 uint32   `protobuf:"varint,3,opt,name=move,proto3" json:"move,omitempty"`
+	XXX_NoUnkeyedLiteral struct{} `json:"-"`
+	XXX_unrecognized     []byte   `json:"-"`
+	XXX_sizecache        int32    `json:"-"`
+}
+
+func (m *Stats) Reset()         { *m = Stats{} }
+func (m *Stats) String() string { return proto.CompactTextString(m) }
+func (*Stats) ProtoMessage()    {}
+func (*Stats) Descriptor() ([]byte, []int) {
+	return fileDescriptor_b704e55df18a3970, []int{3}
+}
+
+func (m *Stats) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_Stats.Unmarshal(m, b)
+}
+func (m *Stats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_Stats.Marshal(b, m, deterministic)
+}
+func (m *Stats) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_Stats.Merge(m, src)
+}
+func (m *Stats) XXX_Size() int {
+	return xxx_messageInfo_Stats.Size(m)
+}
+func (m *Stats) XXX_DiscardUnknown() {
+	xxx_messageInfo_Stats.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Stats proto.InternalMessageInfo
+
+func (m *Stats) GetGeneration() uint32 {
+	if m != nil {
+		return m.Generation
+	}
+	return 0
+}
+
+func (m *Stats) GetIndividual() uint32 {
+	if m != nil {
+		return m.Individual
+	}
+	return 0
+}
+
+func (m *Stats) GetMove() uint32 {
+	if m != nil {
+		return m.Move
+	}
+	return 0
+}
+
 type None struct {
 	XXX_NoUnkeyedLiteral struct{} `json:"-"`
 	XXX_unrecognized     []byte   `json:"-"`
@@ -209,7 +264,7 @@ func (m *None) Reset()         { *m = None{} }
 func (m *None) String() string { return proto.CompactTextString(m) }
 func (*None) ProtoMessage()    {}
 func (*None) Descriptor() ([]byte, []int) {
-	return fileDescriptor_b704e55df18a3970, []int{3}
+	return fileDescriptor_b704e55df18a3970, []int{4}
 }
 
 func (m *None) XXX_Unmarshal(b []byte) error {
@@ -235,31 +290,35 @@ func init() {
 	proto.RegisterType((*Point)(nil), "snakesimulator.Point")
 	proto.RegisterType((*Snake)(nil), "snakesimulator.Snake")
 	proto.RegisterType((*Field)(nil), "snakesimulator.Field")
+	proto.RegisterType((*Stats)(nil), "snakesimulator.Stats")
 	proto.RegisterType((*None)(nil), "snakesimulator.None")
 }
 
 func init() { proto.RegisterFile("snakesimulator.proto", fileDescriptor_b704e55df18a3970) }
 
 var fileDescriptor_b704e55df18a3970 = []byte{
-	// 273 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x91, 0xc1, 0x4b, 0xc3, 0x30,
-	0x14, 0x87, 0x97, 0xb6, 0xa9, 0xee, 0x4d, 0x47, 0x79, 0x54, 0x29, 0x9e, 0x4a, 0xbc, 0x4c, 0xc1,
-	0x21, 0x13, 0xc4, 0x9b, 0x97, 0xe1, 0x49, 0x44, 0x3a, 0x76, 0x77, 0xda, 0xd4, 0x86, 0xcd, 0xa4,
-	0xb4, 0x91, 0x6d, 0x77, 0xff, 0x70, 0x79, 0x69, 0x3d, 0x38, 0x11, 0x76, 0xcb, 0xef, 0xf1, 0xbe,
-	0x7c, 0xef, 0x25, 0x10, 0x37, 0x7a, 0xb1, 0x94, 0x8d, 0xfa, 0xf8, 0x5c, 0x2d, 0xac, 0xa9, 0xc7,
-	0x55, 0x6d, 0xac, 0xc1, 0xe1, 0xef, 0xaa, 0x38, 0x07, 0xfe, 0x6c, 0x94, 0xb6, 0x78, 0x04, 0x6c,
-	0x93, 0xb0, 0x94, 0x8d, 0x8e, 0x33, 0xb6, 0xa1, 0xb4, 0x4d, 0xbc, 0x36, 0x6d, 0xc5, 0x2d, 0xf0,
-	0x19, 0x61, 0x78, 0x05, 0x61, 0x45, 0xdd, 0x4d, 0xc2, 0x52, 0x7f, 0x34, 0x98, 0x9c, 0x8c, 0x77,
-	0x24, 0xee, 0xae, 0xac, 0x6b, 0x12, 0x2f, 0xc0, 0x1f, 0x94, 0x5c, 0xe5, 0x18, 0x03, 0x5f, 0xab,
-	0xdc, 0x96, 0x9d, 0xa0, 0x0d, 0x78, 0x0a, 0x61, 0x29, 0xd5, 0x7b, 0x69, 0x3b, 0x53, 0x97, 0xf0,
-	0x02, 0x82, 0xc2, 0x98, 0x3c, 0xf1, 0x53, 0xf6, 0xbf, 0xc3, 0xb5, 0x88, 0x10, 0x82, 0x27, 0xa3,
-	0xe5, 0xe5, 0x3d, 0xf4, 0xa7, 0xaa, 0x96, 0x6f, 0x56, 0x19, 0x8d, 0x03, 0x38, 0x98, 0xeb, 0xa5,
-	0x36, 0x6b, 0x1d, 0xf5, 0x30, 0x04, 0x6f, 0x5e, 0x45, 0x0c, 0x0f, 0x21, 0x98, 0x52, 0xc5, 0xa3,
-	0xd3, 0xa3, 0x2c, 0x6c, 0xe4, 0x63, 0x1f, 0x78, 0x46, 0xc6, 0x28, 0x98, 0x7c, 0x31, 0x18, 0xba,
-	0x1d, 0x67, 0x3f, 0x1e, 0xbc, 0x03, 0xee, 0xcc, 0x18, 0xef, 0x4e, 0x40, 0xca, 0xb3, 0x3f, 0x73,
-	0x39, 0x5c, 0xf4, 0xae, 0x19, 0x91, 0x45, 0xbb, 0xf7, 0x7e, 0xa4, 0x7b, 0x24, 0x22, 0x5f, 0x43,
-	0xf7, 0x4b, 0x37, 0xdf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x15, 0x7b, 0x53, 0xb6, 0xbd, 0x01, 0x00,
-	0x00,
+	// 329 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x4b, 0xfb, 0x30,
+	0x18, 0xc6, 0x97, 0xad, 0xe9, 0xf7, 0xbb, 0x77, 0x6e, 0x94, 0x97, 0x29, 0xc5, 0x83, 0x8c, 0x78,
+	0x99, 0x82, 0x43, 0x26, 0x88, 0x37, 0x2f, 0xc3, 0x93, 0x88, 0x74, 0xec, 0xe4, 0xc5, 0x6a, 0xb3,
+	0x35, 0xac, 0x4b, 0x4a, 0x9b, 0xfd, 0xfa, 0xe7, 0xfc, 0xdb, 0x24, 0x69, 0x06, 0x73, 0x22, 0xec,
+	0x96, 0xf7, 0xe9, 0xf3, 0xe4, 0xc3, 0xf3, 0x36, 0xd0, 0x2d, 0x65, 0x3c, 0xe7, 0xa5, 0x58, 0x2c,
+	0xb3, 0x58, 0xab, 0x62, 0x90, 0x17, 0x4a, 0x2b, 0xec, 0xfc, 0x54, 0xd9, 0x25, 0xd0, 0x57, 0x25,
+	0xa4, 0xc6, 0x13, 0x20, 0x9b, 0x90, 0xf4, 0x48, 0xbf, 0x1d, 0x91, 0x8d, 0x99, 0xb6, 0x61, 0xbd,
+	0x9a, 0xb6, 0xec, 0x1e, 0xe8, 0xd8, 0xc4, 0xf0, 0x06, 0xfc, 0xdc, 0xb8, 0xcb, 0x90, 0xf4, 0x1a,
+	0xfd, 0xd6, 0xf0, 0x74, 0x70, 0x00, 0xb1, 0x77, 0x45, 0xce, 0xc4, 0xde, 0x81, 0x3e, 0x09, 0x9e,
+	0x25, 0xd8, 0x05, 0xba, 0x16, 0x89, 0x4e, 0x1d, 0xa0, 0x1a, 0xf0, 0x0c, 0xfc, 0x94, 0x8b, 0x59,
+	0xaa, 0x1d, 0xc9, 0x4d, 0x78, 0x05, 0xde, 0x54, 0xa9, 0x24, 0x6c, 0xf4, 0xc8, 0xdf, 0x0c, 0x6b,
+	0x61, 0x6f, 0x40, 0xc7, 0x3a, 0xd6, 0x25, 0x5e, 0x00, 0xcc, 0xb8, 0xe4, 0x45, 0xac, 0x85, 0x92,
+	0x0e, 0xb3, 0xa7, 0x98, 0xef, 0x42, 0x26, 0x62, 0x25, 0x92, 0x65, 0x9c, 0x39, 0xde, 0x9e, 0x82,
+	0x08, 0xde, 0x42, 0xad, 0xb8, 0x65, 0xb6, 0x23, 0x7b, 0x66, 0x3e, 0x78, 0x2f, 0x4a, 0xf2, 0xeb,
+	0x47, 0x68, 0x8e, 0x44, 0xc1, 0x3f, 0xed, 0x45, 0x2d, 0xf8, 0x37, 0x91, 0x73, 0xa9, 0xd6, 0x32,
+	0xa8, 0xa1, 0x0f, 0xf5, 0x49, 0x1e, 0x10, 0xfc, 0x0f, 0xde, 0xc8, 0x28, 0x75, 0x73, 0x7a, 0xe6,
+	0x53, 0x1d, 0x34, 0xb0, 0x09, 0x34, 0x32, 0x75, 0x02, 0x6f, 0xf8, 0x45, 0xa0, 0x63, 0x17, 0x38,
+	0xde, 0x95, 0xc0, 0x07, 0xa0, 0xb6, 0x16, 0x76, 0x0f, 0xeb, 0x19, 0xe4, 0xf9, 0xaf, 0xd2, 0x36,
+	0xce, 0x6a, 0xb7, 0xc4, 0x24, 0xa7, 0xd5, 0x52, 0x8f, 0x4b, 0xda, 0x3f, 0xb0, 0x4b, 0x96, 0x76,
+	0x59, 0xc7, 0x32, 0x8d, 0xd9, 0x24, 0x3f, 0x7c, 0xfb, 0x78, 0xee, 0xbe, 0x03, 0x00, 0x00, 0xff,
+	0xff, 0x74, 0x50, 0xa2, 0x2d, 0x54, 0x02, 0x00, 0x00,
 }
 
 // Reference imports to suppress errors if they are not otherwise used.
@@ -276,6 +335,7 @@ const _ = grpc.SupportPackageIsVersion4
 type SnakeSimulatorClient interface {
 	Snake(ctx context.Context, in *None, opts ...grpc.CallOption) (SnakeSimulator_SnakeClient, error)
 	Field(ctx context.Context, in *None, opts ...grpc.CallOption) (SnakeSimulator_FieldClient, error)
+	Stats(ctx context.Context, in *None, opts ...grpc.CallOption) (SnakeSimulator_StatsClient, error)
 }
 
 type snakeSimulatorClient struct {
@@ -350,10 +410,43 @@ func (x *snakeSimulatorFieldClient) Recv() (*Field, error) {
 	return m, nil
 }
 
+func (c *snakeSimulatorClient) Stats(ctx context.Context, in *None, opts ...grpc.CallOption) (SnakeSimulator_StatsClient, error) {
+	stream, err := c.cc.NewStream(ctx, &_SnakeSimulator_serviceDesc.Streams[2], "/snakesimulator.SnakeSimulator/stats", opts...)
+	if err != nil {
+		return nil, err
+	}
+	x := &snakeSimulatorStatsClient{stream}
+	if err := x.ClientStream.SendMsg(in); err != nil {
+		return nil, err
+	}
+	if err := x.ClientStream.CloseSend(); err != nil {
+		return nil, err
+	}
+	return x, nil
+}
+
+type SnakeSimulator_StatsClient interface {
+	Recv() (*Stats, error)
+	grpc.ClientStream
+}
+
+type snakeSimulatorStatsClient struct {
+	grpc.ClientStream
+}
+
+func (x *snakeSimulatorStatsClient) Recv() (*Stats, error) {
+	m := new(Stats)
+	if err := x.ClientStream.RecvMsg(m); err != nil {
+		return nil, err
+	}
+	return m, nil
+}
+
 // SnakeSimulatorServer is the server API for SnakeSimulator service.
 type SnakeSimulatorServer interface {
 	Snake(*None, SnakeSimulator_SnakeServer) error
 	Field(*None, SnakeSimulator_FieldServer) error
+	Stats(*None, SnakeSimulator_StatsServer) error
 }
 
 // UnimplementedSnakeSimulatorServer can be embedded to have forward compatible implementations.
@@ -366,6 +459,9 @@ func (*UnimplementedSnakeSimulatorServer) Snake(req *None, srv SnakeSimulator_Sn
 func (*UnimplementedSnakeSimulatorServer) Field(req *None, srv SnakeSimulator_FieldServer) error {
 	return status.Errorf(codes.Unimplemented, "method Field not implemented")
 }
+func (*UnimplementedSnakeSimulatorServer) Stats(req *None, srv SnakeSimulator_StatsServer) error {
+	return status.Errorf(codes.Unimplemented, "method Stats not implemented")
+}
 
 func RegisterSnakeSimulatorServer(s *grpc.Server, srv SnakeSimulatorServer) {
 	s.RegisterService(&_SnakeSimulator_serviceDesc, srv)
@@ -413,6 +509,27 @@ func (x *snakeSimulatorFieldServer) Send(m *Field) error {
 	return x.ServerStream.SendMsg(m)
 }
 
+func _SnakeSimulator_Stats_Handler(srv interface{}, stream grpc.ServerStream) error {
+	m := new(None)
+	if err := stream.RecvMsg(m); err != nil {
+		return err
+	}
+	return srv.(SnakeSimulatorServer).Stats(m, &snakeSimulatorStatsServer{stream})
+}
+
+type SnakeSimulator_StatsServer interface {
+	Send(*Stats) error
+	grpc.ServerStream
+}
+
+type snakeSimulatorStatsServer struct {
+	grpc.ServerStream
+}
+
+func (x *snakeSimulatorStatsServer) Send(m *Stats) error {
+	return x.ServerStream.SendMsg(m)
+}
+
 var _SnakeSimulator_serviceDesc = grpc.ServiceDesc{
 	ServiceName: "snakesimulator.SnakeSimulator",
 	HandlerType: (*SnakeSimulatorServer)(nil),
@@ -428,6 +545,11 @@ var _SnakeSimulator_serviceDesc = grpc.ServiceDesc{
 			Handler:       _SnakeSimulator_Field_Handler,
 			ServerStreams: true,
 		},
+		{
+			StreamName:    "stats",
+			Handler:       _SnakeSimulator_Stats_Handler,
+			ServerStreams: true,
+		},
 	},
 	Metadata: "snakesimulator.proto",
 }

+ 8 - 1
neuralnetwork/snakesimulator/snakesimulator.proto

@@ -50,10 +50,17 @@ message Field {
     Point food = 3;
 }
 
+message Stats {
+    uint32 generation = 1;
+    uint32 individual = 2;
+    uint32 move = 3;
+}
+
 message None {
 }
 
 service SnakeSimulator {
     rpc snake(None) returns (stream Snake) {}
     rpc field(None) returns (stream Field) {}
-}
+    rpc stats(None) returns (stream Stats) {}
+}

+ 7 - 0
neuralnetwork/snakesimulator/snakesimulatorui/main.cpp

@@ -49,6 +49,7 @@ int main(int argc, char *argv[])
 
     snakesimulator::Snake *snake = new snakesimulator::Snake;
     snakesimulator::Field *field = new snakesimulator::Field;
+    snakesimulator::Stats *stats = new snakesimulator::Stats;
 
     QObject::connect(client.get(), &snakesimulator::SnakeSimulatorClient::fieldUpdated, [field](const snakesimulator::Field & _field){
         *field = _field;
@@ -58,12 +59,18 @@ int main(int argc, char *argv[])
         *snake = _snake;
     });
 
+    QObject::connect(client.get(), &snakesimulator::SnakeSimulatorClient::statsUpdated, [stats](const snakesimulator::Stats & _stats){
+        *stats = _stats;
+    });
+
     client->subscribeFieldUpdates({});
     client->subscribeSnakeUpdates({});
+    client->subscribeStatsUpdates({});
 
     QQmlApplicationEngine engine;
     engine.rootContext()->setContextProperty("field", field);
     engine.rootContext()->setContextProperty("snake", snake);
+    engine.rootContext()->setContextProperty("stats", stats);
     engine.load(QUrl(QStringLiteral("qrc:/main.qml")));
     if (engine.rootObjects().isEmpty())
         return -1;

+ 20 - 0
neuralnetwork/snakesimulator/snakesimulatorui/main.qml

@@ -58,6 +58,26 @@ ApplicationWindow {
         height: tileSize
     }
 
+    Column {
+        anchors {
+            right: parent.right
+            top: parent.top
+            margins: 10
+        }
+        Text {
+            color: "#ddffee"
+            text: "Generation: " + stats.generation
+        }
+        Text {
+            color: "#ddffee"
+            text: "Individual: " + stats.individual
+        }
+        Text {
+            color: "#ddffee"
+            text: "Move: " + stats.move
+        }
+    }
+
     Connections {
         target: field
         onWidthChanged: {