package snakesimulator

import (
	context "context"
	fmt "fmt"
	math "math"
	"math/rand"
	"net"
	"sort"
	"sync"
	"time"

	"gonum.org/v1/gonum/mat"

	genetic "../genetic"
	neuralnetwork "../neuralnetwork"
	grpc "google.golang.org/grpc"
)

// SnakeSimulator evaluates snake-playing neural networks and publishes the
// game state to a GUI client through gRPC streams.
type SnakeSimulator struct {
	// fieldReadMutex guards food regeneration on field.
	fieldReadMutex sync.Mutex
	field          *Field

	// snakeReadMutex guards replacing, feeding and moving snake.
	snakeReadMutex sync.Mutex
	snake          *Snake

	// Counters reported to the GUI, and the number of moves a snake may make
	// without eating before its game ends.
	stats                *Stats
	maxVerificationSteps int

	// GUI interface: preview speed (0 disables the throttled preview) and the
	// queue of speed changes requested by the client.
	speed      uint32
	speedQueue chan uint32

	// Signals telling the Field, Snake and Stats streams to resend their state.
	fieldUpdateQueue chan bool
	snakeUpdateQueue chan bool
	statsUpdateQueue chan bool
}

// NewSnakeSimulator returns a simulator with a 40x40 field, an initial
// three-segment snake and GUI preview speed 10. maxVerificationSteps is the
// number of moves a snake may make without eating before its game ends.
func NewSnakeSimulator(maxVerificationSteps int) (s *SnakeSimulator) {
	field := &Field{
		Width:  40,
		Height: 40,
		Food:   &Point{},
	}
	snake := &Snake{
		Points: []*Point{
			{X: 20, Y: 20},
			{X: 21, Y: 20},
			{X: 22, Y: 20},
		},
	}

	// Each GUI stream may lag at most this many update signals behind.
	const updateQueueDepth = 2
	return &SnakeSimulator{
		field:                field,
		snake:                snake,
		stats:                &Stats{},
		maxVerificationSteps: maxVerificationSteps,
		speed:                10,
		speedQueue:           make(chan uint32, 1),
		fieldUpdateQueue:     make(chan bool, updateQueueDepth),
		snakeUpdateQueue:     make(chan bool, updateQueueDepth),
		statsUpdateQueue:     make(chan bool, updateQueueDepth),
	}
}

// Verify evaluates every individual of the population by letting its network
// play one game, then replays the best individual at preview speed so the GUI
// can show it.
//
// An individual's fitness is the number of moves made multiplied by
// (final snake length - 2).
//
// The returned slice is sorted by descending fitness (best first); use the
// Index field to map an entry back to population.Networks. An empty
// population yields an empty slice and no replay.
//
// NOTE(review): the stats update signals are sent with blocking channel sends
// (buffer of 2), so this call stalls if no GUI client drains the Stats stream.
func (s *SnakeSimulator) Verify(population *genetic.Population) (fitnesses []*genetic.IndividalFitness) {
	// Speed used while replaying the best individual.
	const showcaseSpeed = 5

	s.stats.Generation++
	s.statsUpdateQueue <- true

	// Regenerate food under the same lock the Field stream reads it with.
	s.fieldReadMutex.Lock()
	s.field.GenerateNextFood()
	s.fieldReadMutex.Unlock()
	if s.speed > 0 {
		s.fieldUpdateQueue <- true
	}

	fitnesses = make([]*genetic.IndividalFitness, len(population.Networks))
	for index, individual := range population.Networks {
		s.stats.Individual = uint32(index)
		s.statsUpdateQueue <- true

		s.runSnake(individual, false)
		fitnesses[index] = &genetic.IndividalFitness{
			// Fitness: float64(s.stats.Move), //Uncomment this to decrease food impact to individual selection
			Fitness: float64(s.stats.Move) * float64(len(s.snake.Points)-2),
			Index:   index,
		}
	}

	// This duplicates the sorting done by crossbreedPopulation; it is needed
	// here to find the best snake to display. Descending: best first.
	sort.Slice(fitnesses, func(i, j int) bool {
		return fitnesses[i].Fitness > fitnesses[j].Fitness
	})

	if len(fitnesses) == 0 {
		return fitnesses
	}

	// Best snake showtime!
	s.fieldReadMutex.Lock()
	s.field.GenerateNextFood()
	s.fieldReadMutex.Unlock()
	s.fieldUpdateQueue <- true

	prevSpeed := s.speed
	s.speed = showcaseSpeed
	s.runSnake(population.Networks[fitnesses[0].Index], false)
	// runSnake applies speed changes requested through SetSpeed. Only restore
	// the previous speed if the user did not choose a new one during the
	// replay; otherwise their request would be silently discarded.
	if s.speed == showcaseSpeed {
		s.speed = prevSpeed
	}
	return fitnesses
}

// runSnake plays a single game driven by the given network and leaves the
// number of moves made in s.stats.Move.
//
// The snake is recreated at the start of every game, heading left or, when
// randomStart is set, in a random direction. At each step the network picks
// the next direction from getHeadState. The game ends on a self collision, a
// wall collision, or once maxVerificationSteps steps pass without eating
// (eating resets the step counter).
//
// When s.speed > 0 each step sleeps 100/speed ms and signals the stats and
// snake GUI streams; those sends block if no client drains the queues.
// Speed changes queued by SetSpeed are picked up here, at most one per step.
func (s *SnakeSimulator) runSnake(individual *neuralnetwork.NeuralNetwork, randomStart bool) {
	s.snakeReadMutex.Lock()
	if randomStart {
		// Network outputs are mapped to directions as predictIndex+1 below, so
		// draw from the same 1-based range. NOTE(review): this assumes
		// Direction 0 is the unset enum value — confirm against its definition.
		// A private source avoids reseeding the package-global generator on
		// every game.
		rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
		s.snake = NewSnake(Direction(rnd.Uint32()%4+1), *s.field)
	} else {
		s.snake = NewSnake(Direction_Left, *s.field)
	}
	s.snakeReadMutex.Unlock()

	s.stats.Move = 0
	for i := 0; i < s.maxVerificationSteps; i++ {
		// Read speed from client and sleep in case if user selected slow preview
		select {
		case newSpeed := <-s.speedQueue:
			fmt.Printf("Apply new speed: %v\n", newSpeed)
			// newSpeed is unsigned, so only the upper bound needs checking;
			// requests above 10 are ignored.
			if newSpeed <= 10 {
				s.speed = newSpeed
			}
		default:
		}

		if s.speed > 0 {
			time.Sleep(100 / time.Duration(s.speed) * time.Millisecond)
			s.statsUpdateQueue <- true
			s.snakeUpdateQueue <- true
		}

		// NOTE(review): the second result of Predict is discarded — confirm it
		// is not an error that should end the game.
		predictIndex, _ := individual.Predict(mat.NewDense(individual.Sizes[0], 1, s.getHeadState()))
		direction := Direction(predictIndex + 1)
		newHead := s.snake.NewHead(direction)

		if s.snake.selfCollision(newHead, direction) {
			fmt.Printf("Game over self collision\n")
			break
		} else if wallCollision(newHead, *s.field) {
			break
		} else if foodCollision(newHead, s.field.Food) {
			// Eating grants a fresh step budget.
			i = 0
			s.snakeReadMutex.Lock()
			s.snake.Feed(newHead)
			s.snakeReadMutex.Unlock()
			s.fieldReadMutex.Lock()
			s.field.GenerateNextFood()
			s.fieldReadMutex.Unlock()
			if s.speed > 0 {
				s.fieldUpdateQueue <- true
			}
		} else {
			s.snakeReadMutex.Lock()
			s.snake.Move(newHead)
			s.snakeReadMutex.Unlock()
		}
		s.stats.Move++
	}
}

// getHeadState builds the input activations for the neural network from the
// current snake and field.
//
// Coordinate convention (the one the wall distances use): X grows to the
// right and Y grows downwards, so "left" means smaller X and "top" means
// smaller Y. The same convention is applied to every straight and diagonal
// direction below.
//
// The returned slice holds 24 values in this order:
//
//	0-3   distance from the head to the left, right, top and bottom wall,
//	      divided by the field width/height
//	4-7   distance along the top-left, top-right, bottom-left and bottom-right
//	      diagonal to the nearest wall, divided by the field diagonal
//	8-11  1 if the food lies straight left, right, above or below the head,
//	      else 0
//	12-15 1 if the food lies on the top-left, top-right, bottom-left or
//	      bottom-right diagonal of the head, else 0
//	16-19 distance to the tail when it lies straight above, below, left or
//	      right of the head, divided by the field height/width, else 0
//	20-23 1 if the tail lies on the top-left, top-right, bottom-left or
//	      bottom-right diagonal of the head, else 0
//
// The field is assumed to be square (the diagonal is derived from the width).
func (s *SnakeSimulator) getHeadState() []float64 {
	head := s.snake.Points[0]
	tail := s.snake.Points[len(s.snake.Points)-1]

	headX, headY := float64(head.X), float64(head.Y)
	tailX, tailY := float64(tail.X), float64(tail.Y)
	foodX, foodY := float64(s.field.Food.X), float64(s.field.Food.Y)
	width := float64(s.field.Width)
	height := float64(s.field.Height)
	diag := width * math.Sqrt2 // assumes a square field

	// Straight distances to the walls.
	lWall := headX
	rWall := width - headX
	tWall := headY
	bWall := height - headY

	// A diagonal ray stops at whichever of its two walls is closer.
	tlWall := math.Min(lWall, tWall) * math.Sqrt2
	trWall := math.Min(rWall, tWall) * math.Sqrt2
	blWall := math.Min(lWall, bWall) * math.Sqrt2
	brWall := math.Min(rWall, bWall) * math.Sqrt2

	// Food on a straight line from the head.
	var lFood, rFood, tFood, bFood float64
	if headX == foodX {
		if headY > foodY {
			tFood = 1
		} else {
			bFood = 1
		}
	}
	if headY == foodY {
		if headX > foodX {
			lFood = 1
		} else {
			rFood = 1
		}
	}

	// Food on one of the four diagonals.
	var tlFood, trFood, blFood, brFood float64
	if math.Abs(foodY-headY) == math.Abs(foodX-headX) {
		above, right := foodY < headY, foodX > headX
		switch {
		case above && right:
			trFood = 1
		case above:
			tlFood = 1
		case right:
			brFood = 1
		default:
			blFood = 1
		}
	}

	// Tail on a straight line from the head: the activation is the
	// (non-negative) distance to it.
	var tTail, bTail, lTail, rTail float64
	if headX == tailX {
		if headY > tailY {
			tTail = headY - tailY
		} else {
			bTail = tailY - headY
		}
	}
	if headY == tailY {
		if headX > tailX {
			lTail = headX - tailX
		} else {
			rTail = tailX - headX
		}
	}

	// Tail on one of the four diagonals.
	var tlTail, trTail, blTail, brTail float64
	if math.Abs(tailY-headY) == math.Abs(tailX-headX) {
		above, right := tailY < headY, tailX > headX
		switch {
		case above && right:
			trTail = 1
		case above:
			tlTail = 1
		case right:
			brTail = 1
		default:
			blTail = 1
		}
	}

	return []float64{
		lWall / width,
		rWall / width,
		tWall / height,
		bWall / height,
		tlWall / diag,
		trWall / diag,
		blWall / diag,
		brWall / diag,
		lFood,
		rFood,
		tFood,
		bFood,
		tlFood,
		trFood,
		blFood,
		brFood,
		tTail / height,
		bTail / height,
		lTail / width,
		rTail / width,
		tlTail,
		trTail,
		blTail,
		brTail,
	}
}

// Server part

// StartServer starts the gRPC server for the GUI on localhost:65002 in a
// background goroutine and returns immediately. Listen and serve failures are
// only printed; they do not stop the simulator.
func (s *SnakeSimulator) StartServer() {
	go func() {
		grpcServer := grpc.NewServer()
		RegisterSnakeSimulatorServer(grpcServer, s)
		lis, err := net.Listen("tcp", "localhost:65002")
		if err != nil {
			fmt.Printf("Failed to listen: %v\n", err)
			// Without a listener there is nothing to serve; calling Serve
			// with a nil listener would panic.
			return
		}

		fmt.Printf("Listen SnakeSimulator localhost:65002\n")
		if err := grpcServer.Serve(lis); err != nil {
			fmt.Printf("Failed to serve: %v\n", err)
		}
	}()
}

// Field streams the playing field (including the food position) to a GUI
// client. The current field is sent right away and again every time the
// simulator signals fieldUpdateQueue, until the client goes away or a send
// fails.
func (s *SnakeSimulator) Field(_ *None, srv SnakeSimulator_FieldServer) error {
	ctx := srv.Context()
	for {
		// The simulator regenerates food under fieldReadMutex, so the field
		// must be read under that same lock.
		s.fieldReadMutex.Lock()
		err := srv.Send(s.field)
		s.fieldReadMutex.Unlock()
		if err != nil {
			return err
		}

		// Wait for the next update, but stop promptly once the client is gone
		// so a stale handler does not consume signals meant for a new one.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-s.fieldUpdateQueue:
		}
	}
}

// Snake streams the snake's position and length to a GUI client. The current
// snake is sent right away and again every time the simulator signals
// snakeUpdateQueue, until the client goes away or a send fails.
func (s *SnakeSimulator) Snake(_ *None, srv SnakeSimulator_SnakeServer) error {
	ctx := srv.Context()
	for {
		// runSnake replaces and mutates s.snake under snakeReadMutex, so it
		// must be read under that same lock.
		s.snakeReadMutex.Lock()
		err := srv.Send(s.snake)
		s.snakeReadMutex.Unlock()
		if err != nil {
			return err
		}

		// Wait for the next update, but stop promptly once the client is gone
		// so a stale handler does not consume signals meant for a new one.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-s.snakeUpdateQueue:
		}
	}
}

// Stats streams the simulator statistics (generation, individual, move count)
// to a GUI client. The current stats are sent right away and again every time
// the simulator signals statsUpdateQueue, until the client goes away or a send
// fails.
func (s *SnakeSimulator) Stats(_ *None, srv SnakeSimulator_StatsServer) error {
	ctx := srv.Context()
	for {
		// NOTE(review): Verify and runSnake update s.stats without holding any
		// lock, so this lock does not prevent a data race on the stats; a
		// dedicated mutex also taken by the writers would be needed.
		s.fieldReadMutex.Lock()
		err := srv.Send(s.stats)
		s.fieldReadMutex.Unlock()
		if err != nil {
			return err
		}

		// Wait for the next update, but stop promptly once the client is gone
		// so a stale handler does not consume signals meant for a new one.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-s.statsUpdateQueue:
		}
	}
}

// SetSpeed queues a preview-speed change requested by the GUI client; the
// simulation loop applies it on its next step. The call waits while an
// earlier request is still pending and returns the context error if the
// client cancels or the deadline passes first.
func (s *SnakeSimulator) SetSpeed(ctx context.Context, speed *Speed) (*None, error) {
	select {
	case s.speedQueue <- speed.Speed:
		return &None{}, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}