package neuralnetworkbase

import (
	teach "../teach"
	mat "gonum.org/v1/gonum/mat"
)

// NeuralNetwork is a simple neural network implementation
//
// Matrix: A
// Description: A is set of calculated neuron activations after sigmoid correction
// Format:    0            n          N
//         ⎡A[0] ⎤ ... ⎡A[0] ⎤ ... ⎡A[0] ⎤
//         ⎢A[1] ⎥ ... ⎢A[1] ⎥ ... ⎢A[1] ⎥
//         ⎢ ... ⎥ ... ⎢ ... ⎥ ... ⎢ ... ⎥
//         ⎢A[i] ⎥ ... ⎢A[i] ⎥ ... ⎢A[i] ⎥
//         ⎢ ... ⎥ ... ⎢ ... ⎥ ... ⎢ ... ⎥
//         ⎣A[s] ⎦ ... ⎣A[s] ⎦ ... ⎣A[s] ⎦
// Where s = Sizes[n], N = len(Sizes)
//
// Matrix: Z
// Description: Z is set of calculated raw neuron activations
// Format:    0            n          N
//         ⎡Z[0] ⎤ ... ⎡Z[0] ⎤ ... ⎡Z[0] ⎤
//         ⎢Z[1] ⎥ ... ⎢Z[1] ⎥ ... ⎢Z[1] ⎥
//         ⎢ ... ⎥ ... ⎢ ... ⎥ ... ⎢ ... ⎥
//         ⎢Z[i] ⎥ ... ⎢Z[i] ⎥ ... ⎢Z[i] ⎥
//         ⎢ ... ⎥ ... ⎢ ... ⎥ ... ⎢ ... ⎥
//         ⎣Z[s] ⎦ ... ⎣Z[s] ⎦ ... ⎣Z[s] ⎦
// Where s = Sizes[n], N = len(Sizes)
//
// Matrix: Biases
// Description: Biases is the set of biases for each layer except L0
// Format:
//         ⎡b[0] ⎤
//         ⎢b[1] ⎥
//         ⎢ ... ⎥
//         ⎢b[i] ⎥
//         ⎢ ... ⎥
//         ⎣b[s] ⎦
// Where s = Sizes[n]
//
// Matrix: Weights
// Description: Weights is the set of weights for each layer except L0
// Format:
//         ⎡w[0,0] ... w[0,j] ... w[0,s']⎤
//         ⎢w[1,0] ... w[1,j] ... w[1,s']⎥
//         ⎢              ...            ⎥
//         ⎢w[i,0] ... w[i,j] ... w[i,s']⎥
//         ⎢              ...            ⎥
//         ⎣w[s,0] ... w[s,j] ... w[s,s']⎦
// Where s = Sizes[n], s' = Sizes[n-1]

// NeuralNetwork holds the layer sizes, the learned weights and biases,
// and the activations (A, Z) produced by the most recent forward pass.
type NeuralNetwork struct {
	Count          int          // number of layers, i.e. len(Sizes)
	Sizes          []int        // neuron count per layer; Sizes[0] is the input layer
	Biases         []*mat.Dense // bias column vector per layer; index 0 (input layer) is unused
	Weights        []*mat.Dense // weight matrix per layer; index 0 (input layer) is unused
	A              []*mat.Dense // sigmoid-corrected activations per layer from the last forward pass
	Z              []*mat.Dense // raw (pre-sigmoid) activations per layer from the last forward pass
	alpha          float64      // learning rate scaled by input layer size (nu / Sizes[0])
	trainingCycles int          // number of full passes over the teacher's data in Teach
}

// NewNeuralNetwork builds a network with the given layer sizes, learning
// rate nu and training cycle count. Layer 0 is the input layer and keeps
// nil Weights/Biases; every later layer gets randomly initialized ones.
func NewNeuralNetwork(Sizes []int, nu float64, trainingCycles int) (nn *NeuralNetwork) {
	layerCount := len(Sizes)
	nn = &NeuralNetwork{
		Sizes:          Sizes,
		Count:          layerCount,
		Weights:        make([]*mat.Dense, layerCount),
		Biases:         make([]*mat.Dense, layerCount),
		A:              make([]*mat.Dense, layerCount),
		Z:              make([]*mat.Dense, layerCount),
		alpha:          nu / float64(Sizes[0]),
		trainingCycles: trainingCycles,
	}

	// Skip layer 0: the input layer has no incoming weights or biases.
	for layer := 1; layer < layerCount; layer++ {
		nn.Weights[layer] = generateRandomDense(Sizes[layer], Sizes[layer-1])
		nn.Biases[layer] = generateRandomDense(Sizes[layer], 1)
	}
	return nn
}

// Predict runs a forward pass on the input column vector aIn and returns
// the index of the output neuron with the highest activation together
// with that activation value.
func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
	nn.forward(aIn)
	result := nn.result()
	r, _ := result.Dims()
	if r == 0 {
		// Degenerate network with an empty output layer.
		return 0, 0
	}
	// Seed the running maximum from the first neuron rather than 0.0 so the
	// argmax stays correct even if activations can be non-positive (the
	// original seed of 0.0 only worked because sigmoid outputs are in (0,1)).
	maxIndex = 0
	max = result.At(0, 0)
	for i := 1; i < r; i++ {
		if v := result.At(i, 0); v > max {
			max = v
			maxIndex = i
		}
	}
	return
}

// Teach trains the network by running backpropagation over the teacher's
// data set for the configured number of training cycles.
func (nn *NeuralNetwork) Teach(teacher teach.Teacher) {
	for cycle := 0; cycle < nn.trainingCycles; cycle++ {
		// NOTE(review): this assumes teacher.Next() rewinds after a full
		// pass; otherwise only the first cycle sees any data — confirm
		// against the teach.Teacher implementation.
		for teacher.Next() {
			nn.backward(teacher.GetData(), teacher.GetExpect())
		}
	}
}

// SaveState is a placeholder for persisting the trained weights and
// biases to the file named filename. Not implemented yet.
func (nn *NeuralNetwork) SaveState(filename string) {

}

// LoadState is a placeholder for restoring weights and biases previously
// written by SaveState from the file named filename. Not implemented yet.
func (nn *NeuralNetwork) LoadState(filename string) {

}

// forward propagates the input column vector aIn through every layer,
// storing both the raw activations (Z) and the sigmoid-corrected
// activations (A) for later use by backpropagation.
func (nn *NeuralNetwork) forward(aIn mat.Matrix) {
	nn.A[0] = mat.DenseCopyOf(aIn)

	for l := 1; l < nn.Count; l++ {
		// Per-layer activation: A[l] = σ(W[l]*A[l−1] + B[l])
		activation := mat.NewDense(nn.Sizes[l], 1, nil)
		activation.Mul(nn.Weights[l], nn.A[l-1])
		activation.Add(activation, nn.Biases[l])

		// Keep the raw (pre-sigmoid) value around for backpropagation.
		nn.Z[l] = mat.DenseCopyOf(activation)

		// Apply the sigmoid correction in place.
		activation.Apply(applySigmoid, activation)
		nn.A[l] = activation
	}
}

// backward performs one backpropagation step: it runs a forward pass on
// aIn, computes the per-layer deltas against the expected output aOut,
// and replaces nn.Weights and nn.Biases with the gradient-adjusted
// values produced by makeBackGradien.
func (nn *NeuralNetwork) backward(aIn, aOut mat.Matrix) {
	nn.forward(aIn)

	lastLayerNum := nn.Count - 1

	//To calculate new values of weights and biases
	//the following formulas are used:
	//W[l] = A[l−1]*δ[l]
	//B[l] = δ[l]

	//For the last layer the δ value is calculated as follows:
	//δ = (A[L]−y)⊙σ'(Z[L])

	//Calculate the initial error for the last layer L
	//error = A[L]-y
	//Where y is the expected activations set
	err := &mat.Dense{}
	err.Sub(nn.result(), aOut)

	//Calculate the sigmoid prime σ'(Z[L]) for the last layer L
	sigmoidsPrime := &mat.Dense{}
	sigmoidsPrime.Apply(applySigmoidPrime, nn.Z[lastLayerNum])

	//(A[L]−y)⊙σ'(Z[L])
	delta := &mat.Dense{}
	delta.MulElem(err, sigmoidsPrime)

	//B[L] = δ[L]
	biases := mat.DenseCopyOf(delta)

	//W[L] = A[L−1]*δ[L]
	weights := &mat.Dense{}
	weights.Mul(delta, nn.A[lastLayerNum-1].T())

	//Initialize the new weights and biases with the last layer's values.
	//makeBackGradien presumably combines the gradient with the old value
	//scaled by alpha — confirm against its implementation.
	newBiases := []*mat.Dense{makeBackGradien(biases, nn.Biases[lastLayerNum], nn.alpha)}
	newWeights := []*mat.Dense{makeBackGradien(weights, nn.Weights[lastLayerNum], nn.alpha)}

	//Save the calculated delta value in the temporary error variable
	err = delta

	//The remaining layers' Weights and Biases are calculated using the same formulas:
	//W[l] = A[l−1]*δ[l]
	//B[l] = δ[l]

	//But δ[l] is calculated using a different formula:
	//δ[l] = ((Wt[l+1])*δ[l+1])⊙σ'(Z[l])
	//Where Wt[l+1] is the transposed matrix of the actual Weights from
	//the forward step
	for l := nn.Count - 2; l > 0; l-- {
		//Calculate the sigmoid prime σ'(Z[l]) for layer l
		sigmoidsPrime := &mat.Dense{}
		sigmoidsPrime.Apply(applySigmoidPrime, nn.Z[l])

		//(Wt[l+1])*δ[l+1]
		//err below is the delta from the previous step (l+1)
		delta := &mat.Dense{}
		wdelta := &mat.Dense{}
		wdelta.Mul(nn.Weights[l+1].T(), err)

		//Calculate the new delta and store it in the temporary variable err
		//δ[l] = ((Wt[l+1])*δ[l+1])⊙σ'(Z[l])
		delta.MulElem(wdelta, sigmoidsPrime)
		err = delta

		//B[l] = δ[l]
		biases := mat.DenseCopyOf(delta)

		//W[l] = A[l−1]*δ[l]
		//At this point an explanation is required for the inaccuracy
		//in the formula

		//Multiplying the activations matrix for layer l-1 by δ[l] is impossible
		//because the shapes of the matrices are the following:
		//          A[l-1]       δ[l]
		//         ⎡A[0]  ⎤     ⎡δ[0] ⎤
		//         ⎢A[1]  ⎥     ⎢δ[1] ⎥
		//         ⎢ ...  ⎥     ⎢ ... ⎥
		//         ⎢A[i]  ⎥  X  ⎢δ[i] ⎥
		//         ⎢ ...  ⎥     ⎢ ... ⎥
		//         ⎣A[s'] ⎦     ⎣δ[s] ⎦
		//So we need to modify these matrices to make the multiplication valid
		//and get a Weights matrix of the following shape:
		//         ⎡w[0,0] ... w[0,j] ... w[0,s']⎤
		//         ⎢w[1,0] ... w[1,j] ... w[1,s']⎥
		//         ⎢              ...            ⎥
		//         ⎢w[i,0] ... w[i,j] ... w[i,s']⎥
		//         ⎢              ...            ⎥
		//         ⎣w[s,0] ... w[s,j] ... w[s,s']⎦
		//So we swap the operands and transpose A[l-1] to get a valid
		//multiplication of the following shape:
		//           δ[l]               A[l-1]
		//         ⎡δ[0] ⎤ x [A[0] A[1] ... A[i] ... A[s']]
		//         ⎢δ[1] ⎥
		//         ⎢ ... ⎥
		//         ⎢δ[i] ⎥
		//         ⎢ ... ⎥
		//         ⎣δ[s] ⎦
		weights := &mat.Dense{}
		weights.Mul(delta, nn.A[l-1].T())

		//!Prepend! the new Biases and Weights so layer order is preserved
		// Scale down
		newBiases = append([]*mat.Dense{makeBackGradien(biases, nn.Biases[l], nn.alpha)}, newBiases...)
		newWeights = append([]*mat.Dense{makeBackGradien(weights, nn.Weights[l], nn.alpha)}, newWeights...)
	}

	//Layer 0 (the input layer) has no weights or biases; prepend empty
	//placeholders so indexes stay aligned with nn.Sizes.
	newBiases = append([]*mat.Dense{&mat.Dense{}}, newBiases...)
	newWeights = append([]*mat.Dense{&mat.Dense{}}, newWeights...)

	nn.Biases = newBiases
	nn.Weights = newWeights
}

// result returns the activation vector of the output (last) layer as
// computed by the most recent forward pass.
func (nn *NeuralNetwork) result() *mat.Dense {
	lastLayer := nn.Count - 1
	return nn.A[lastLayer]
}