|
@@ -0,0 +1,246 @@
|
|
|
|
+/*
|
|
|
|
+ * MIT License
|
|
|
|
+ *
|
|
|
|
+ * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
|
|
|
|
+ *
|
|
|
|
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
|
|
|
|
+ *
|
|
|
|
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
|
|
|
|
+ * software and associated documentation files (the "Software"), to deal in the Software
|
|
|
|
+ * without restriction, including without limitation the rights to use, copy, modify,
|
|
|
|
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
|
|
|
|
+ * to permit persons to whom the Software is furnished to do so, subject to the following
|
|
|
|
+ * conditions:
|
|
|
|
+ *
|
|
|
|
+ * The above copyright notice and this permission notice shall be included in all copies
|
|
|
|
+ * or substantial portions of the Software.
|
|
|
|
+ *
|
|
|
|
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
|
|
|
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
|
|
|
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
|
|
|
|
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
|
|
|
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|
|
|
+ * DEALINGS IN THE SOFTWARE.
|
|
|
|
+ */
|
|
|
|
+
|
|
|
|
+package main
|
|
|
|
+
|
|
|
|
+import (
|
|
|
|
+ context "context"
|
|
|
|
+ fmt "fmt"
|
|
|
|
+ "log"
|
|
|
|
+ "net"
|
|
|
|
+ "sync"
|
|
|
|
+ "time"
|
|
|
|
+
|
|
|
|
+ "./visualization"
|
|
|
|
+ neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
|
|
|
|
+ remotecontrol "git.semlanik.org/semlanik/NeuralNetwork/remotecontrol"
|
|
|
|
+ "gonum.org/v1/gonum/mat"
|
|
|
|
+ grpc "google.golang.org/grpc"
|
|
|
|
+
|
|
|
|
+ training "git.semlanik.org/semlanik/NeuralNetwork/training"
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+type RemoteControl struct {
|
|
|
|
+ nn *neuralnetwork.NeuralNetwork
|
|
|
|
+ activationsQueue chan *remotecontrol.LayerMatrix
|
|
|
|
+ biasesQueue chan *remotecontrol.LayerMatrix
|
|
|
|
+ weightsQueue chan *remotecontrol.LayerMatrix
|
|
|
|
+ stateQueue chan int
|
|
|
|
+ mutex sync.Mutex
|
|
|
|
+ config *remotecontrol.Configuration
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func NewRemoteControl() (rw *RemoteControl) {
|
|
|
|
+ rw = &RemoteControl{}
|
|
|
|
+ rw.activationsQueue = make(chan *remotecontrol.LayerMatrix, 5)
|
|
|
|
+ rw.biasesQueue = make(chan *remotecontrol.LayerMatrix, 5)
|
|
|
|
+ rw.weightsQueue = make(chan *remotecontrol.LayerMatrix, 5)
|
|
|
|
+ rw.stateQueue = make(chan int, 2)
|
|
|
|
+ rw.config = &remotecontrol.Configuration{}
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) Init(nn *neuralnetwork.NeuralNetwork) {
|
|
|
|
+ rw.nn = nn
|
|
|
|
+ for _, size := range rw.nn.Sizes {
|
|
|
|
+ rw.config.Sizes = append(rw.config.Sizes, int32(size))
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) UpdateActivations(l int, a *mat.Dense) {
|
|
|
|
+ matrix := NewLayerMatrix(l, a, remotecontrol.LayerMatrix_Activations)
|
|
|
|
+ select {
|
|
|
|
+ case rw.activationsQueue <- matrix:
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) UpdateBiases(l int, biases *mat.Dense) {
|
|
|
|
+ matrix := NewLayerMatrix(l, biases, remotecontrol.LayerMatrix_Biases)
|
|
|
|
+ select {
|
|
|
|
+ case rw.biasesQueue <- matrix:
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) UpdateWeights(l int, weights *mat.Dense) {
|
|
|
|
+ matrix := NewLayerMatrix(l, weights, remotecontrol.LayerMatrix_Weights)
|
|
|
|
+ select {
|
|
|
|
+ case rw.weightsQueue <- matrix:
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) UpdateState(state int) {
|
|
|
|
+ select {
|
|
|
|
+ case rw.stateQueue <- state:
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func NewLayerMatrix(l int, dense *mat.Dense, contentType remotecontrol.LayerMatrix_ContentType) (matrix *remotecontrol.LayerMatrix) {
|
|
|
|
+ buffer, err := dense.MarshalBinary()
|
|
|
|
+ if err != nil {
|
|
|
|
+ log.Fatalln("Invalid dense is provided for remote control")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ matrix = &remotecontrol.LayerMatrix{
|
|
|
|
+ Matrix: &remotecontrol.Matrix{
|
|
|
|
+ Matrix: buffer,
|
|
|
|
+ },
|
|
|
|
+ Layer: int32(l),
|
|
|
|
+ ContentType: contentType,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) GetConfiguration(context.Context, *remotecontrol.None) (*remotecontrol.Configuration, error) {
|
|
|
|
+ return rw.config, nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) Activations(_ *remotecontrol.None, srv remotecontrol.RemoteControl_ActivationsServer) error {
|
|
|
|
+ ctx := srv.Context()
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
+ return ctx.Err()
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+ msg := <-rw.activationsQueue
|
|
|
|
+ srv.Send(msg)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) Biases(_ *remotecontrol.None, srv remotecontrol.RemoteControl_BiasesServer) error {
|
|
|
|
+ ctx := srv.Context()
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
+ return ctx.Err()
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+ msg := <-rw.biasesQueue
|
|
|
|
+ srv.Send(msg)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) Weights(_ *remotecontrol.None, srv remotecontrol.RemoteControl_WeightsServer) error {
|
|
|
|
+ ctx := srv.Context()
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
+ return ctx.Err()
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+ msg := <-rw.weightsQueue
|
|
|
|
+ srv.Send(msg)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) State(_ *remotecontrol.None, srv remotecontrol.RemoteControl_StateServer) error {
|
|
|
|
+ ctx := srv.Context()
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
+ return ctx.Err()
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+ state := <-rw.stateQueue
|
|
|
|
+ msg := &remotecontrol.NetworkState{
|
|
|
|
+ State: remotecontrol.NetworkState_State(state),
|
|
|
|
+ }
|
|
|
|
+ srv.Send(msg)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) Run(context.Context, *visualization.None) (*visualization.None, error) {
|
|
|
|
+ go func() {
|
|
|
|
+ rw.mutex.Lock()
|
|
|
|
+ defer rw.mutex.Unlock()
|
|
|
|
+ // trainer := training.NewMNISTReader("./minst.data", "./mnist.labels")
|
|
|
|
+ trainer := training.NewTextDataReader("wine.data", 5)
|
|
|
|
+ rw.nn.Train(trainer, 500)
|
|
|
|
+
|
|
|
|
+ // for i := 0; i < nn.Count; i++ {
|
|
|
|
+ // if i > 0 {
|
|
|
|
+ // fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
|
|
|
|
+ // fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
|
|
|
|
+ // fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
|
|
|
|
+ // }
|
|
|
|
+ // fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
|
|
|
|
+ // }
|
|
|
|
+
|
|
|
|
+ rw.nn.SaveStateToFile("./neuralnetworkdata.nnd")
|
|
|
|
+
|
|
|
|
+ rw.UpdateState(neuralnetwork.StateLearning)
|
|
|
|
+ defer rw.UpdateState(neuralnetwork.StateIdle)
|
|
|
|
+ failCount := 0
|
|
|
|
+ for i := 0; i < trainer.ValidatorCount(); i++ {
|
|
|
|
+ dataSet, expect := trainer.GetValidator(i)
|
|
|
|
+ index, _ := rw.nn.Predict(dataSet)
|
|
|
|
+ //TODO: remove this if not used for visualization
|
|
|
|
+ time.Sleep(400 * time.Millisecond)
|
|
|
|
+ if expect.At(index, 0) != 1.0 {
|
|
|
|
+ failCount++
|
|
|
|
+ // fmt.Printf("Fail: %v, %v\n\n", trainer.ValidationIndex(), expect.At(index, 0))
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fmt.Printf("Fail count: %v\n\n", failCount)
|
|
|
|
+ failCount = 0
|
|
|
|
+ rw.UpdateState(neuralnetwork.StateIdle)
|
|
|
|
+ }()
|
|
|
|
+
|
|
|
|
+ return &visualization.None{}, nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (rw *RemoteControl) RunServices() {
|
|
|
|
+ go func() {
|
|
|
|
+ grpcServer := grpc.NewServer()
|
|
|
|
+ remotecontrol.RegisterRemoteControlServer(grpcServer, rw)
|
|
|
|
+ lis, err := net.Listen("tcp", "localhost:65001")
|
|
|
|
+ if err != nil {
|
|
|
|
+ fmt.Printf("Failed to listen: %v\n", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fmt.Printf("Listen localhost:65001\n")
|
|
|
|
+ if err := grpcServer.Serve(lis); err != nil {
|
|
|
|
+ fmt.Printf("Failed to serve: %v\n", err)
|
|
|
|
+ }
|
|
|
|
+ }()
|
|
|
|
+
|
|
|
|
+ grpcServer := grpc.NewServer()
|
|
|
|
+ visualization.RegisterVisualizationServer(grpcServer, rw)
|
|
|
|
+ lis, err := net.Listen("tcp", "localhost:65002")
|
|
|
|
+ if err != nil {
|
|
|
|
+ fmt.Printf("Failed to listen: %v\n", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fmt.Printf("Listen localhost:65002\n")
|
|
|
|
+ if err := grpcServer.Serve(lis); err != nil {
|
|
|
|
+ fmt.Printf("Failed to serve: %v\n", err)
|
|
|
|
+ }
|
|
|
|
+}
|