ソースを参照

Move remote control implementaion from NeuralNetwork library

- Add extra proto interface to run education samples
Alexey Edelev 4 年 前
コミット
7434b57be1

+ 9 - 0
build.sh

@@ -1,6 +1,15 @@
 export GOPATH=$PWD
 export PATH=$PATH:$PWD/bin
 export GOBIN=$PWD/bin
+export RPC_PATH=$PWD/visualization
+
+go get github.com/golang/protobuf/protoc-gen-go
+go install ./src/github.com/golang/protobuf/protoc-gen-go
+
+mkdir -p $RPC_PATH
+rm -f $RPC_PATH/*.pb.go
+protoc -I$RPC_PATH --go_out=plugins=grpc:$RPC_PATH $RPC_PATH/visualization.proto
 
 go get -v
 go build -o $GOBIN/neuralnetwork
+

+ 2 - 3
main.go

@@ -28,11 +28,10 @@ package main
 import (
 	"git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
 	"git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork/gradients"
-	"git.semlanik.org/semlanik/NeuralNetwork/remotecontrol"
 )
 
 func main() {
-	rc := remotecontrol.NewRemoteControl()
+	rc := NewRemoteControl()
 	sizes := []int{13, 8, 12, 3}
 	nn, _ := neuralnetwork.NewNeuralNetwork(sizes, gradients.NewRPropInitializer(gradients.RPropConfig{
 		NuPlus:   1.2,
@@ -42,7 +41,7 @@ func main() {
 	}))
 
 	nn.SetStateWatcher(rc)
-	rc.Run()
+	rc.RunServices()
 
 	// inFile, err := os.Open("./networkstate")
 	// if err != nil {

+ 3 - 3
neuralnetworkui/CMakeLists.txt

@@ -16,12 +16,12 @@ if(Qt5_POSITION_INDEPENDENT_CODE)
 endif()
 
 file(GLOB PROTO_FILES ABSOLUTE "${CMAKE_CURRENT_SOURCE_DIR}/../src/git.semlanik.org/semlanik/NeuralNetwork/remotecontrol/remotecontrol.proto")
-
-message("PROTO_FILES: ${PROTO_FILES}")
+file(GLOB VISUALIZATION_PROTO "${CMAKE_CURRENT_SOURCE_DIR}/../visualization/visualization.proto")
+message("PROTO_FILES: ${PROTO_FILES} ${VISUALIZATION_PROTO}")
 
 set(CMAKE_AUTOMOC ON)
 set(CMAKE_AUTORCC ON)
 
 add_executable(NeuralNetworkUi main.cpp qml.qrc valueindicator.cpp visualizermodel.cpp dense.cpp layertrigger.cpp)
-generate_qtprotobuf(TARGET ${TARGET} PROTO_FILES ${PROTO_FILES} QML TRUE)
+generate_qtprotobuf(TARGET ${TARGET} PROTO_FILES ${PROTO_FILES} ${VISUALIZATION_PROTO} QML TRUE)
 target_link_libraries(NeuralNetworkUi PRIVATE Qt5::Core Qt5::Gui Qt5::Qml Qt5::Quick QtProtobufProject::QtProtobuf QtProtobufProject::QtGrpc ${QtProtobuf_GENERATED})

+ 7 - 3
neuralnetworkui/main.cpp

@@ -31,6 +31,7 @@
 
 #include "qtprotobuf_global.qpb.h"
 #include "remotecontrol_grpc.qpb.h"
+#include "visualization_grpc.qpb.h"
 
 #include "qgrpchttp2channel.h"
 #include <QGrpcInsecureCredentials>
@@ -47,10 +48,13 @@ int main(int argc, char *argv[])
     qmlRegisterUncreatableType<ValueIndicator>("NeuralNetworkUi", 0, 1, "ValueIndicator", "");
     qmlRegisterUncreatableType<LayerTrigger>("NeuralNetworkUi", 0, 1, "LayerTrigger", "");
     std::shared_ptr<remotecontrol::RemoteControlClient> client(new remotecontrol::RemoteControlClient);
-    auto chan = std::shared_ptr<QtProtobuf::QGrpcHttp2Channel>(new QtProtobuf::QGrpcHttp2Channel(QUrl("http://localhost:65001"), QtProtobuf::QGrpcInsecureCallCredentials()|QtProtobuf::QGrpcInsecureChannelCredentials()));
-    client->attachChannel(chan);
+    std::shared_ptr<visualization::VisualizationClient> vClient(new visualization::VisualizationClient);
+    client->attachChannel(std::shared_ptr<QtProtobuf::QGrpcHttp2Channel>(new QtProtobuf::QGrpcHttp2Channel(QUrl("http://localhost:65001"),
+                                                                                                           QtProtobuf::QGrpcInsecureCallCredentials()|QtProtobuf::QGrpcInsecureChannelCredentials())));
+    vClient->attachChannel(std::shared_ptr<QtProtobuf::QGrpcHttp2Channel>(new QtProtobuf::QGrpcHttp2Channel(QUrl("http://localhost:65002"),
+                                                                                                            QtProtobuf::QGrpcInsecureCallCredentials()|QtProtobuf::QGrpcInsecureChannelCredentials())));
 
-    std::unique_ptr<VisualizerModel> visualizerModel(new VisualizerModel(client));
+    std::unique_ptr<VisualizerModel> visualizerModel(new VisualizerModel(client, vClient));
 
     QQmlApplicationEngine engine;
     engine.rootContext()->setContextProperty("visualizerModel", visualizerModel.get());

+ 6 - 3
neuralnetworkui/visualizermodel.cpp

@@ -31,12 +31,16 @@
 #include <QQmlEngine>
 
 #include "remotecontrol.qpb.h"
+#include "visualization.qpb.h"
 
 using namespace remotecontrol;
 using namespace QtProtobuf;
 
-VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, QObject *parent) : QObject(parent)
+VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client,
+                                 std::shared_ptr<visualization::VisualizationClient> &visualizationClient,
+                                 QObject *parent) : QObject(parent)
   , m_client(client)
+  , m_visualizationClient(visualizationClient)
   , m_networkState(new NetworkState{NetworkState::None})
 {
     m_client->getConfiguration({}, this, [this](QGrpcAsyncReply *reply) {
@@ -122,8 +126,7 @@ LayerTrigger *VisualizerModel::weightTrigger(int layer)
     return trigger;
 }
 
-
 void VisualizerModel::start()
 {
-    m_client->dummyStart({});
+    m_visualizationClient->Run({});
 }

+ 5 - 1
neuralnetworkui/visualizermodel.h

@@ -30,6 +30,7 @@
 
 #include "remotecontrol.qpb.h"
 #include "remotecontrol_grpc.qpb.h"
+#include "visualization_grpc.qpb.h"
 
 #include "valueindicator.h"
 #include "abstractdense.h"
@@ -52,7 +53,9 @@ class VisualizerModel : public QObject
     Q_PROPERTY(QList<int> sizes READ sizes NOTIFY sizesChanged)
     Q_PROPERTY(remotecontrol::NetworkState* networkState READ networkState CONSTANT)
 public:
-    explicit VisualizerModel(std::shared_ptr<remotecontrol::RemoteControlClient> &client, QObject *parent = nullptr);
+    explicit VisualizerModel(std::shared_ptr<remotecontrol::RemoteControlClient> &client,
+                             std::shared_ptr<visualization::VisualizationClient> &visualizationClient,
+                             QObject *parent = nullptr);
 
     QList<int> sizes() {
         return m_networkConfig.sizes();
@@ -74,6 +77,7 @@ signals:
 
 private:
     std::shared_ptr<remotecontrol::RemoteControlClient> &m_client;
+    std::shared_ptr<visualization::VisualizationClient> &m_visualizationClient;
     remotecontrol::Configuration m_networkConfig;
     QList<NetworkLayerState*> m_layers;
     QPointer<remotecontrol::NetworkState> m_networkState;

+ 246 - 0
remotecontrol.go

@@ -0,0 +1,246 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+package main
+
+import (
+	context "context"
+	fmt "fmt"
+	"log"
+	"net"
+	"sync"
+	"time"
+
+	"./visualization"
+	neuralnetwork "git.semlanik.org/semlanik/NeuralNetwork/neuralnetwork"
+	remotecontrol "git.semlanik.org/semlanik/NeuralNetwork/remotecontrol"
+	"gonum.org/v1/gonum/mat"
+	grpc "google.golang.org/grpc"
+
+	training "git.semlanik.org/semlanik/NeuralNetwork/training"
+)
+
+type RemoteControl struct {
+	nn               *neuralnetwork.NeuralNetwork
+	activationsQueue chan *remotecontrol.LayerMatrix
+	biasesQueue      chan *remotecontrol.LayerMatrix
+	weightsQueue     chan *remotecontrol.LayerMatrix
+	stateQueue       chan int
+	mutex            sync.Mutex
+	config           *remotecontrol.Configuration
+}
+
+func NewRemoteControl() (rw *RemoteControl) {
+	rw = &RemoteControl{}
+	rw.activationsQueue = make(chan *remotecontrol.LayerMatrix, 5)
+	rw.biasesQueue = make(chan *remotecontrol.LayerMatrix, 5)
+	rw.weightsQueue = make(chan *remotecontrol.LayerMatrix, 5)
+	rw.stateQueue = make(chan int, 2)
+	rw.config = &remotecontrol.Configuration{}
+	return
+}
+
+func (rw *RemoteControl) Init(nn *neuralnetwork.NeuralNetwork) {
+	rw.nn = nn
+	for _, size := range rw.nn.Sizes {
+		rw.config.Sizes = append(rw.config.Sizes, int32(size))
+	}
+}
+
+func (rw *RemoteControl) UpdateActivations(l int, a *mat.Dense) {
+	matrix := NewLayerMatrix(l, a, remotecontrol.LayerMatrix_Activations)
+	select {
+	case rw.activationsQueue <- matrix:
+	default:
+	}
+}
+
+func (rw *RemoteControl) UpdateBiases(l int, biases *mat.Dense) {
+	matrix := NewLayerMatrix(l, biases, remotecontrol.LayerMatrix_Biases)
+	select {
+	case rw.biasesQueue <- matrix:
+	default:
+	}
+}
+
+func (rw *RemoteControl) UpdateWeights(l int, weights *mat.Dense) {
+	matrix := NewLayerMatrix(l, weights, remotecontrol.LayerMatrix_Weights)
+	select {
+	case rw.weightsQueue <- matrix:
+	default:
+	}
+}
+
+func (rw *RemoteControl) UpdateState(state int) {
+	select {
+	case rw.stateQueue <- state:
+	default:
+	}
+}
+
+func NewLayerMatrix(l int, dense *mat.Dense, contentType remotecontrol.LayerMatrix_ContentType) (matrix *remotecontrol.LayerMatrix) {
+	buffer, err := dense.MarshalBinary()
+	if err != nil {
+		log.Fatalln("Invalid dense is provided for remote control")
+	}
+
+	matrix = &remotecontrol.LayerMatrix{
+		Matrix: &remotecontrol.Matrix{
+			Matrix: buffer,
+		},
+		Layer:       int32(l),
+		ContentType: contentType,
+	}
+
+	return
+}
+
+func (rw *RemoteControl) GetConfiguration(context.Context, *remotecontrol.None) (*remotecontrol.Configuration, error) {
+	return rw.config, nil
+}
+
+func (rw *RemoteControl) Activations(_ *remotecontrol.None, srv remotecontrol.RemoteControl_ActivationsServer) error {
+	ctx := srv.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		msg := <-rw.activationsQueue
+		srv.Send(msg)
+	}
+}
+
+func (rw *RemoteControl) Biases(_ *remotecontrol.None, srv remotecontrol.RemoteControl_BiasesServer) error {
+	ctx := srv.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		msg := <-rw.biasesQueue
+		srv.Send(msg)
+	}
+}
+
+func (rw *RemoteControl) Weights(_ *remotecontrol.None, srv remotecontrol.RemoteControl_WeightsServer) error {
+	ctx := srv.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		msg := <-rw.weightsQueue
+		srv.Send(msg)
+	}
+}
+
+func (rw *RemoteControl) State(_ *remotecontrol.None, srv remotecontrol.RemoteControl_StateServer) error {
+	ctx := srv.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		state := <-rw.stateQueue
+		msg := &remotecontrol.NetworkState{
+			State: remotecontrol.NetworkState_State(state),
+		}
+		srv.Send(msg)
+	}
+}
+
+func (rw *RemoteControl) Run(context.Context, *visualization.None) (*visualization.None, error) {
+	go func() {
+		rw.mutex.Lock()
+		defer rw.mutex.Unlock()
+		// trainer := training.NewMNISTReader("./minst.data", "./mnist.labels")
+		trainer := training.NewTextDataReader("wine.data", 5)
+		rw.nn.Train(trainer, 500)
+
+		// for i := 0; i < nn.Count; i++ {
+		// 	if i > 0 {
+		// 		fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
+		// 		fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
+		// 		fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
+		// 	}
+		// 	fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
+		// }
+
+		rw.nn.SaveStateToFile("./neuralnetworkdata.nnd")
+
+		rw.UpdateState(neuralnetwork.StateLearning)
+		defer rw.UpdateState(neuralnetwork.StateIdle)
+		failCount := 0
+		for i := 0; i < trainer.ValidatorCount(); i++ {
+			dataSet, expect := trainer.GetValidator(i)
+			index, _ := rw.nn.Predict(dataSet)
+			//TODO: remove this if not used for visualization
+			time.Sleep(400 * time.Millisecond)
+			if expect.At(index, 0) != 1.0 {
+				failCount++
+				// fmt.Printf("Fail: %v, %v\n\n", trainer.ValidationIndex(), expect.At(index, 0))
+			}
+		}
+
+		fmt.Printf("Fail count: %v\n\n", failCount)
+		failCount = 0
+		rw.UpdateState(neuralnetwork.StateIdle)
+	}()
+
+	return &visualization.None{}, nil
+}
+
+func (rw *RemoteControl) RunServices() {
+	go func() {
+		grpcServer := grpc.NewServer()
+		remotecontrol.RegisterRemoteControlServer(grpcServer, rw)
+		lis, err := net.Listen("tcp", "localhost:65001")
+		if err != nil {
+			fmt.Printf("Failed to listen: %v\n", err)
+		}
+
+		fmt.Printf("Listen localhost:65001\n")
+		if err := grpcServer.Serve(lis); err != nil {
+			fmt.Printf("Failed to serve: %v\n", err)
+		}
+	}()
+
+	grpcServer := grpc.NewServer()
+	visualization.RegisterVisualizationServer(grpcServer, rw)
+	lis, err := net.Listen("tcp", "localhost:65002")
+	if err != nil {
+		fmt.Printf("Failed to listen: %v\n", err)
+	}
+
+	fmt.Printf("Listen localhost:65002\n")
+	if err := grpcServer.Serve(lis); err != nil {
+		fmt.Printf("Failed to serve: %v\n", err)
+	}
+}

+ 35 - 0
visualization/visualization.proto

@@ -0,0 +1,35 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+ syntax="proto3";
+
+package visualization;
+
+message None {
+}
+
+service Visualization {
+    rpc Run(None) returns (None) {}
+}