/*
 * MIT License
 *
 * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
 *
 * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this
 * software and associated documentation files (the "Software"), to deal in the Software
 * without restriction, including without limitation the rights to use, copy, modify,
 * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
 * to permit persons to whom the Software is furnished to do so, subject to the following
 * conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies
 * or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
 * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
 * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 */

package handwriting

import (
	"bytes"
	context "context"
	fmt "fmt"
	"net"

	earlystop "git.semlanik.org/semlanik/NeuralNetwork/earlystop"
	neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
	gradients "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork/gradients"
	training "git.semlanik.org/semlanik/NeuralNetwork/training"

	mat "gonum.org/v1/gonum/mat"
	grpc "google.golang.org/grpc"
)

// HandwritingService implements the Handwriting gRPC server (registered in Run)
// on top of a single neural network. It is also installed as that network's
// state watcher by NewHandwritingService, which is why it carries the
// Init/Update*/GetSubscriptionFeatures callback methods below.
//
// NOTE(review): gRPC may dispatch Recognize, SetNeuralNetworkData and ReTrain
// concurrently, and all of them use nn without any locking here. Confirm whether
// neuralnetwork.NeuralNetwork tolerates that, or guard nn with a mutex.
type HandwritingService struct {
	nn *neuralnetwork.NeuralNetwork // network used for prediction, state loading and retraining
}

// NewHandwritingService builds a service backed by a fresh 784-300-10 network
// that uses the RProp gradient initializer, and registers the service as that
// network's state watcher.
//
// The constructor has no error result. If the network cannot be created it
// therefore panics with the underlying error, instead of silently continuing
// with an unusable network.
func NewHandwritingService() (hws *HandwritingService) {
	nn, err := neuralnetwork.NewNeuralNetwork([]int{784, 300, 10}, gradients.NewRPropInitializer(gradients.RPropConfig{
		NuPlus:   1.2,
		NuMinus:  0.5,
		DeltaMax: 50.0,
		DeltaMin: 0.000001,
	}))
	if err != nil {
		// Previously this error was discarded and the (presumably nil) network
		// was used right away; fail here so the real cause is visible.
		panic(fmt.Sprintf("handwriting: creating neural network: %v", err))
	}

	hws = &HandwritingService{nn: nn}
	nn.SetStateWatcher(hws)
	return hws
}

// Recognize classifies one flattened image sent by a client.
//
// matrix.Data must hold exactly hws.nn.Sizes[0] values, which is the width of
// the network's input layer. The values are wrapped in a column vector and
// passed to Predict, and the returned index is reported as ResultCharacter.
//
// A missing or wrongly sized payload is rejected with an error. mat.NewDense
// panics on a length mismatch, and a panic inside a gRPC handler would bring
// down the whole server. An empty payload would otherwise be silently
// accepted as an all-zero image.
func (hws *HandwritingService) Recognize(ctx context.Context, matrix *Matrix) (*Result, error) {
	inputSize := hws.nn.Sizes[0]
	if matrix == nil {
		return nil, fmt.Errorf("recognize: missing input matrix")
	}

	fmt.Printf("Recognize %v size: %v\n", len(matrix.Data), inputSize)
	if len(matrix.Data) != inputSize {
		return nil, fmt.Errorf("recognize: input has %d values, network expects %d", len(matrix.Data), inputSize)
	}

	dense := mat.NewDense(inputSize, 1, matrix.Data)
	// Only the predicted index is used; Predict's second result is ignored.
	index, _ := hws.nn.Predict(dense)
	fmt.Printf("Recognition result %v\n", index)
	return &Result{ResultCharacter: uint32(index)}, nil
}

// SetNeuralNetworkData replaces the network's state with the serialized
// snapshot carried in nnRaw.Data, feeding it to NeuralNetwork.LoadState.
//
// NOTE(review): any failure reported by LoadState is not propagated to the
// caller here; confirm LoadState's signature and surface errors if it has them.
func (hws *HandwritingService) SetNeuralNetworkData(ctx context.Context, nnRaw *NeuralNetworkRaw) (*None, error) {
	fmt.Println("SetNeuralNetworkData")
	hws.nn.LoadState(bytes.NewReader(nnRaw.Data))
	return new(None), nil
}

// GetNeuralNetworkData exports the current network state in
// NeuralNetworkRaw.Data.
//
// The state is serialized into an in-memory buffer with NeuralNetwork.SaveState.
// SaveState is assumed to be the io.Writer counterpart of LoadState, which
// SetNeuralNetworkData uses, and of SaveStateToFile, which ReTrain uses. The
// returned bytes can therefore be fed back through SetNeuralNetworkData.
func (hws *HandwritingService) GetNeuralNetworkData(context.Context, *None) (*NeuralNetworkRaw, error) {
	fmt.Println("GetNeuralNetworkData")
	var buf bytes.Buffer
	hws.nn.SaveState(&buf)
	return &NeuralNetworkRaw{Data: buf.Bytes()}, nil
}

// ReTrain retrains the network on the MNIST files found in the process's
// working directory, then persists the result to ./mnistnet.nnd.
//
// The call proceeds in these steps:
//   - install a simple-descent early-stop policy;
//   - validate once and print the baseline fail count and error;
//   - call Train with 100 epochs and save the state to ./mnistnet.nnd;
//   - validate again and print the new numbers.
//
// The RPC does not return until all of this is done.
//
// NOTE(review): the request context is ignored, so a cancelled client does not
// stop training.
func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) {
	const epochs = 100
	nn := hws.nn

	fmt.Println("ReTrain")

	trainer := training.NewMNISTReader(
		"./train-images-idx3-ubyte",
		"./train-labels-idx1-ubyte",
		"./t10k-images-idx3-ubyte",
		"./t10k-labels-idx1-ubyte",
	)
	nn.SetEarlyStop(earlystop.NewSimpleDescentEarlyStop(nn, trainer))

	errBefore, failsBefore, totalBefore := nn.Validate(trainer)
	fmt.Printf("Fail count before: %v/%v, error: %v\n\n", failsBefore, totalBefore, errBefore)

	nn.Train(trainer, epochs)
	nn.SaveStateToFile("./mnistnet.nnd")

	errAfter, failsAfter, totalAfter := nn.Validate(trainer)
	fmt.Printf("Fail count after: %v/%v, error: %v\n\n", failsAfter, totalAfter, errAfter)

	fmt.Println("ReTrain finished")
	return new(None), nil
}

// Run registers the service on a new gRPC server and serves it on
// localhost:65001, blocking until Serve returns.
//
// Failures are reported on stdout. If the listener cannot be opened, Run
// returns immediately. Previously it went on to call Serve with a nil
// net.Listener, which panics as soon as Serve calls a method on it.
func (hws *HandwritingService) Run() {
	lis, err := net.Listen("tcp", "localhost:65001")
	if err != nil {
		fmt.Printf("Failed to listen: %v\n", err)
		return
	}

	grpcServer := grpc.NewServer()
	RegisterHandwritingServer(grpcServer, hws)

	fmt.Printf("Listen localhost:65001\n")
	if err := grpcServer.Serve(lis); err != nil {
		fmt.Printf("Failed to serve: %v\n", err)
	}
}

// Init is a state-watcher callback that receives the network. It is
// intentionally a no-op: the service already holds its network in hws.nn.
func (hws *HandwritingService) Init(nn *neuralnetwork.NeuralNetwork) {}

// UpdateState is a state-watcher callback; the service ignores state changes.
func (hws *HandwritingService) UpdateState(int) {}

// UpdateActivations is a state-watcher callback; activations are not tracked.
func (hws *HandwritingService) UpdateActivations(int, *mat.Dense) {}

// UpdateBiases is a state-watcher callback; bias updates are not tracked.
func (hws *HandwritingService) UpdateBiases(int, *mat.Dense) {}

// UpdateWeights is a state-watcher callback; weight updates are not tracked.
func (hws *HandwritingService) UpdateWeights(int, *mat.Dense) {}

// UpdateTraining is a state-watcher callback that prints the current epoch
// (t) out of the total (epocs). The per-sample progress arguments are
// received but not reported.
func (hws *HandwritingService) UpdateTraining(t, epocs, samplesProcced, totalSamplesCount int) {
	current, total := t, epocs
	fmt.Printf("Training progress: Epoc: %v/%v\n", current, total)
}

// UpdateValidation is a state-watcher callback for validation results. It is
// currently a no-op, even though GetSubscriptionFeatures requests the
// validation subscription.
func (hws *HandwritingService) UpdateValidation(validatorCount, failCount int) {}

// GetSubscriptionFeatures reports which watcher notifications this service
// wants: the training and validation subscriptions, and nothing else.
// Presumably the network consults this to decide which callbacks to invoke.
func (hws *HandwritingService) GetSubscriptionFeatures() (features neuralnetwork.SubscriptionFeatures) {
	var wanted neuralnetwork.SubscriptionFeatures
	wanted.Set(neuralnetwork.TrainingSubscription)
	wanted.Set(neuralnetwork.ValidationSubscription)
	return wanted
}

// drawImage is a debugging aid. It prints the first 28*28 entries of column 0
// of dense as a 28x28 grid, one image row per output line. Each pixel is shown
// as 1 if its value is strictly positive and 0 otherwise.
func drawImage(dense *mat.Dense) {
	const side = 28
	for row := 0; row < side; row++ {
		for col := 0; col < side; col++ {
			pixel := 0
			if dense.At(row*side+col, 0) > 0 {
				pixel = 1
			}
			fmt.Printf("%v ", pixel)
		}
		fmt.Println()
	}
}