Browse Source

Add state handling for neural network and RC

Alexey Edelev 5 years ago
parent
commit
46c18d10f1

+ 2 - 0
gui/main.qml

@@ -27,6 +27,7 @@ import QtQuick 2.11
 import QtQuick.Window 2.11
 import QtQuick.Controls 1.4
 import NeuralNetworkUi 0.1
+import remotecontrol 1.0
 
 ApplicationWindow {
     id: root
@@ -157,6 +158,7 @@ ApplicationWindow {
         anchors.right: parent.right
         anchors.top: parent.top
         anchors.margins: 20
+        enabled: visualizerModel.networkState.state === NetworkState.Idle
         onClicked: {
             visualizerModel.start();
         }

+ 2 - 0
gui/visualizermodel.cpp

@@ -36,6 +36,7 @@ using namespace QtProtobuf;
 
 VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, QObject *parent) : QObject(parent)
   , m_client(client)
+  , m_networkState(new NetworkState)
 {
     m_client->getConfiguration({}, this, [this](QGrpcAsyncReply *reply) {
         qDeleteAll(m_layers);
@@ -85,6 +86,7 @@ VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, Q
     });
     client->subscribeActivationsUpdates({});
     client->subscribeWeightsUpdates({});
+    client->subscribeStateUpdates({}, m_networkState);
 }
 
 ValueIndicator *VisualizerModel::activation(int layer, int row)

+ 9 - 0
gui/visualizermodel.h

@@ -32,6 +32,7 @@
 #include "valueindicator.h"
 #include "abstractdense.h"
 #include "layertrigger.h"
+#include "networkstate.h"
 
 class ValueIndicator;
 class LayerTrigger;
@@ -48,6 +49,7 @@ class VisualizerModel : public QObject
 {
     Q_OBJECT
     Q_PROPERTY(QList<int> sizes READ sizes NOTIFY sizesChanged)
+    Q_PROPERTY(remotecontrol::NetworkState* networkState READ networkState CONSTANT)
 public:
     explicit VisualizerModel(std::shared_ptr<remotecontrol::RemoteControlClient> &client, QObject *parent = nullptr);
 
@@ -60,6 +62,12 @@ public:
     Q_INVOKABLE LayerTrigger *activationTrigger(int layer);
     Q_INVOKABLE LayerTrigger *weightTrigger(int layer);
     Q_INVOKABLE void start();
+
+    remotecontrol::NetworkState *networkState() const
+    {
+        return m_networkState.data();
+    }
+
 signals:
     void sizesChanged();
 
@@ -67,4 +75,5 @@ private:
     std::shared_ptr<remotecontrol::RemoteControlClient> &m_client;
     remotecontrol::Configuration m_networkConfig;
     QList<NetworkLayerState*> m_layers;
+    QPointer<remotecontrol::NetworkState> m_networkState;
 };

+ 8 - 0
neuralnetwork/neuralnetworkbase/interface.go

@@ -46,8 +46,16 @@ type BatchGradientDescent interface {
 	Gradients() *mat.Dense
 }
 
+const (
+	StateIdle       = 1
+	StateLearning   = 2
+	StateValidation = 3
+	StatePredict    = 4
+)
+
 type StateWatcher interface {
 	Init(nn *NeuralNetwork)
+	UpdateState(state int)
 	UpdateActivations(l int, a *mat.Dense)
 	UpdateBiases(l int, biases *mat.Dense)
 	UpdateWeights(l int, weights *mat.Dense)

+ 12 - 1
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -144,10 +144,17 @@ func NewNeuralNetwork(sizes []int, gradientDescentInitializer GradientDescentIni
 
 func (nn *NeuralNetwork) SetStateWatcher(watcher StateWatcher) {
 	nn.watcher = watcher
-	watcher.Init(nn)
+	if watcher != nil {
+		watcher.Init(nn)
+		watcher.UpdateState(StateIdle)
+	}
 }
 
 func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
+	if nn.watcher != nil {
+		nn.watcher.UpdateState(StatePredict)
+		defer nn.watcher.UpdateState(StateIdle)
+	}
 	r, _ := aIn.Dims()
 	if r != nn.Sizes[0] {
 		fmt.Printf("Invalid rows number of input matrix size: %v\n", r)
@@ -169,6 +176,10 @@ func (nn *NeuralNetwork) Predict(aIn mat.Matrix) (maxIndex int, max float64) {
 }
 
 func (nn *NeuralNetwork) Teach(teacher teach.Teacher, epocs int) {
+	if nn.watcher != nil {
+		nn.watcher.UpdateState(StateLearning)
+		defer nn.watcher.UpdateState(StateIdle)
+	}
 	if _, ok := nn.WGradient[nn.layerCount-1].(OnlineGradientDescent); ok {
 		nn.TeachOnline(teacher, epocs)
 	} else if _, ok := nn.WGradient[nn.layerCount-1].(BatchGradientDescent); ok {

+ 32 - 4
neuralnetwork/remotecontrol/remotecontrol.go

@@ -49,14 +49,16 @@ type RemoteControl struct {
 	activationsQueue chan *LayerMatrix
 	biasesQueue      chan *LayerMatrix
 	weightsQueue     chan *LayerMatrix
+	stateQueue       chan int
 	mutex            sync.Mutex
 }
 
 func (rw *RemoteControl) Init(nn *neuralnetworkbase.NeuralNetwork) {
 	rw.nn = nn
-	rw.activationsQueue = make(chan *LayerMatrix, 10)
-	rw.biasesQueue = make(chan *LayerMatrix, 10)
-	rw.weightsQueue = make(chan *LayerMatrix, 10)
+	rw.activationsQueue = make(chan *LayerMatrix, 5)
+	rw.biasesQueue = make(chan *LayerMatrix, 5)
+	rw.weightsQueue = make(chan *LayerMatrix, 5)
+	rw.stateQueue = make(chan int, 2)
 }
 
 func (rw *RemoteControl) UpdateActivations(l int, a *mat.Dense) {
@@ -83,6 +85,13 @@ func (rw *RemoteControl) UpdateWeights(l int, weights *mat.Dense) {
 	}
 }
 
+func (rw *RemoteControl) UpdateState(state int) {
+	select {
+	case rw.stateQueue <- state:
+	default:
+	}
+}
+
 func NewLayerMatrix(l int, dense *mat.Dense, contentType LayerMatrix_ContentType) (matrix *LayerMatrix) {
 	buffer, err := dense.MarshalBinary()
 	if err != nil {
@@ -148,6 +157,23 @@ func (rw *RemoteControl) Weights(_ *None, srv RemoteControl_WeightsServer) error
 	}
 }
 
+func (rw *RemoteControl) State(_ *None, srv RemoteControl_StateServer) error {
+	ctx := srv.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		state := <-rw.stateQueue
+		msg := &NetworkState{
+			State: NetworkState_State(state),
+		}
+		fmt.Printf("Send state %v %v\n", msg, state)
+		srv.Send(msg)
+	}
+}
+
 func (rw *RemoteControl) Predict(context.Context, *Matrix) (*Matrix, error) {
 	return nil, status.Error(codes.Unimplemented, "Not implemented")
 }
@@ -177,7 +203,8 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 		rw.nn.SaveState(outFile)
 		outFile.Close()
 
-		time.Sleep(5 * time.Second)
+		rw.UpdateState(neuralnetworkbase.StateLearning)
+		defer rw.UpdateState(neuralnetworkbase.StateIdle)
 		failCount := 0
 		teacher.Reset()
 		for teacher.NextValidator() {
@@ -199,6 +226,7 @@ func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
 		fmt.Printf("Fail count: %v\n\n", failCount)
 		failCount = 0
 		teacher.Reset()
+		rw.UpdateState(neuralnetworkbase.StateIdle)
 	}()
 
 	return &None{}, nil

+ 173 - 29
neuralnetwork/remotecontrol/remotecontrol.pb.go

@@ -24,6 +24,40 @@ var _ = math.Inf
 // proto package needs to be updated.
 const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
 
+type NetworkState_State int32
+
+const (
+	NetworkState_None       NetworkState_State = 0
+	NetworkState_Idle       NetworkState_State = 1
+	NetworkState_Learning   NetworkState_State = 2
+	NetworkState_Validation NetworkState_State = 3
+	NetworkState_Predict    NetworkState_State = 4
+)
+
+var NetworkState_State_name = map[int32]string{
+	0: "None",
+	1: "Idle",
+	2: "Learning",
+	3: "Validation",
+	4: "Predict",
+}
+
+var NetworkState_State_value = map[string]int32{
+	"None":       0,
+	"Idle":       1,
+	"Learning":   2,
+	"Validation": 3,
+	"Predict":    4,
+}
+
+func (x NetworkState_State) String() string {
+	return proto.EnumName(NetworkState_State_name, int32(x))
+}
+
+func (NetworkState_State) EnumDescriptor() ([]byte, []int) {
+	return fileDescriptor_9e7470c0107e56c6, []int{0, 0}
+}
+
 type LayerMatrix_ContentType int32
 
 const (
@@ -49,7 +83,46 @@ func (x LayerMatrix_ContentType) String() string {
 }
 
 func (LayerMatrix_ContentType) EnumDescriptor() ([]byte, []int) {
-	return fileDescriptor_9e7470c0107e56c6, []int{1, 0}
+	return fileDescriptor_9e7470c0107e56c6, []int{2, 0}
+}
+
+type NetworkState struct {
+	State                NetworkState_State `protobuf:"varint,1,opt,name=state,proto3,enum=remotecontrol.NetworkState_State" json:"state,omitempty"`
+	XXX_NoUnkeyedLiteral struct{}           `json:"-"`
+	XXX_unrecognized     []byte             `json:"-"`
+	XXX_sizecache        int32              `json:"-"`
+}
+
+func (m *NetworkState) Reset()         { *m = NetworkState{} }
+func (m *NetworkState) String() string { return proto.CompactTextString(m) }
+func (*NetworkState) ProtoMessage()    {}
+func (*NetworkState) Descriptor() ([]byte, []int) {
+	return fileDescriptor_9e7470c0107e56c6, []int{0}
+}
+
+func (m *NetworkState) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_NetworkState.Unmarshal(m, b)
+}
+func (m *NetworkState) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_NetworkState.Marshal(b, m, deterministic)
+}
+func (m *NetworkState) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_NetworkState.Merge(m, src)
+}
+func (m *NetworkState) XXX_Size() int {
+	return xxx_messageInfo_NetworkState.Size(m)
+}
+func (m *NetworkState) XXX_DiscardUnknown() {
+	xxx_messageInfo_NetworkState.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_NetworkState proto.InternalMessageInfo
+
+func (m *NetworkState) GetState() NetworkState_State {
+	if m != nil {
+		return m.State
+	}
+	return NetworkState_None
 }
 
 type Matrix struct {
@@ -63,7 +136,7 @@ func (m *Matrix) Reset()         { *m = Matrix{} }
 func (m *Matrix) String() string { return proto.CompactTextString(m) }
 func (*Matrix) ProtoMessage()    {}
 func (*Matrix) Descriptor() ([]byte, []int) {
-	return fileDescriptor_9e7470c0107e56c6, []int{0}
+	return fileDescriptor_9e7470c0107e56c6, []int{1}
 }
 
 func (m *Matrix) XXX_Unmarshal(b []byte) error {
@@ -104,7 +177,7 @@ func (m *LayerMatrix) Reset()         { *m = LayerMatrix{} }
 func (m *LayerMatrix) String() string { return proto.CompactTextString(m) }
 func (*LayerMatrix) ProtoMessage()    {}
 func (*LayerMatrix) Descriptor() ([]byte, []int) {
-	return fileDescriptor_9e7470c0107e56c6, []int{1}
+	return fileDescriptor_9e7470c0107e56c6, []int{2}
 }
 
 func (m *LayerMatrix) XXX_Unmarshal(b []byte) error {
@@ -157,7 +230,7 @@ func (m *Configuration) Reset()         { *m = Configuration{} }
 func (m *Configuration) String() string { return proto.CompactTextString(m) }
 func (*Configuration) ProtoMessage()    {}
 func (*Configuration) Descriptor() ([]byte, []int) {
-	return fileDescriptor_9e7470c0107e56c6, []int{2}
+	return fileDescriptor_9e7470c0107e56c6, []int{3}
 }
 
 func (m *Configuration) XXX_Unmarshal(b []byte) error {
@@ -195,7 +268,7 @@ func (m *None) Reset()         { *m = None{} }
 func (m *None) String() string { return proto.CompactTextString(m) }
 func (*None) ProtoMessage()    {}
 func (*None) Descriptor() ([]byte, []int) {
-	return fileDescriptor_9e7470c0107e56c6, []int{3}
+	return fileDescriptor_9e7470c0107e56c6, []int{4}
 }
 
 func (m *None) XXX_Unmarshal(b []byte) error {
@@ -217,7 +290,9 @@ func (m *None) XXX_DiscardUnknown() {
 var xxx_messageInfo_None proto.InternalMessageInfo
 
 func init() {
+	proto.RegisterEnum("remotecontrol.NetworkState_State", NetworkState_State_name, NetworkState_State_value)
 	proto.RegisterEnum("remotecontrol.LayerMatrix_ContentType", LayerMatrix_ContentType_name, LayerMatrix_ContentType_value)
+	proto.RegisterType((*NetworkState)(nil), "remotecontrol.NetworkState")
 	proto.RegisterType((*Matrix)(nil), "remotecontrol.Matrix")
 	proto.RegisterType((*LayerMatrix)(nil), "remotecontrol.LayerMatrix")
 	proto.RegisterType((*Configuration)(nil), "remotecontrol.Configuration")
@@ -227,28 +302,34 @@ func init() {
 func init() { proto.RegisterFile("remotecontrol.proto", fileDescriptor_9e7470c0107e56c6) }
 
 var fileDescriptor_9e7470c0107e56c6 = []byte{
-	// 331 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x92, 0x41, 0x4f, 0xc2, 0x40,
-	0x10, 0x85, 0xbb, 0xa0, 0x25, 0x99, 0x15, 0xad, 0x83, 0x1a, 0x42, 0x3c, 0x34, 0x9b, 0x68, 0xb8,
-	0x48, 0x0c, 0x1e, 0xd4, 0x83, 0x26, 0xc2, 0x41, 0x0f, 0x6a, 0x4c, 0x35, 0xf1, 0x5c, 0x61, 0xc5,
-	0x4d, 0x68, 0x97, 0xec, 0x0e, 0x46, 0xfc, 0x15, 0xfe, 0x39, 0xff, 0x8f, 0xe9, 0x82, 0x48, 0x49,
-	0xbd, 0x70, 0xeb, 0xbc, 0xbc, 0xf7, 0xe6, 0xeb, 0xb4, 0x50, 0x33, 0x32, 0xd1, 0x24, 0x7b, 0x3a,
-	0x25, 0xa3, 0x87, 0xad, 0x91, 0xd1, 0xa4, 0xb1, 0x9a, 0x13, 0x45, 0x08, 0xfe, 0x5d, 0x4c, 0x46,
-	0x7d, 0xe0, 0x1e, 0xf8, 0x89, 0x7b, 0xaa, 0xb3, 0x90, 0x35, 0x37, 0xa2, 0xd9, 0x24, 0xbe, 0x19,
-	0xf0, 0xdb, 0x78, 0x22, 0xcd, 0xcc, 0x77, 0x03, 0x3c, 0x0b, 0xcb, 0x94, 0x9e, 0x26, 0x23, 0xe9,
-	0xcc, 0x9b, 0xed, 0xc3, 0x56, 0x7e, 0xd7, 0x42, 0xa0, 0xd5, 0xfd, 0x73, 0x47, 0x8b, 0x51, 0xdc,
-	0x81, 0xf5, 0x61, 0xe6, 0xab, 0x97, 0x42, 0xd6, 0xdc, 0x8e, 0xa6, 0x03, 0x1e, 0xcd, 0x39, 0xca,
-	0x21, 0x6b, 0xf2, 0xf6, 0xee, 0x52, 0xf5, 0xb4, 0x75, 0x8e, 0x77, 0x0a, 0x7c, 0x61, 0x01, 0x6e,
-	0x01, 0xbf, 0xea, 0x91, 0x7a, 0x8f, 0x49, 0xe9, 0xd4, 0x06, 0x1e, 0x72, 0xa8, 0x3c, 0x4b, 0x35,
-	0x78, 0x23, 0x1b, 0x30, 0x04, 0xf0, 0x3b, 0x2a, 0xb6, 0xd2, 0x06, 0x25, 0x71, 0x00, 0xd5, 0xae,
-	0x4e, 0x5f, 0xd5, 0x60, 0x6c, 0x9c, 0x39, 0xc3, 0xb1, 0xea, 0x53, 0xda, 0x3a, 0x0b, 0xcb, 0x19,
-	0x8e, 0x1b, 0x84, 0x0f, 0x6b, 0xf7, 0x3a, 0x95, 0xed, 0xaf, 0x32, 0x54, 0x23, 0x07, 0xd2, 0x9d,
-	0x82, 0xe0, 0x35, 0x04, 0x03, 0x49, 0xf9, 0x8e, 0xda, 0x12, 0x6c, 0x16, 0x6d, 0xec, 0x2f, 0x89,
-	0xb9, 0x88, 0xf0, 0xb0, 0x93, 0x63, 0x2e, 0xee, 0x68, 0xfc, 0x7f, 0x60, 0xe1, 0x1d, 0x33, 0xbc,
-	0xf8, 0x7d, 0xb3, 0xd5, 0xe2, 0x97, 0xf3, 0x2b, 0xad, 0x96, 0x3f, 0x87, 0xca, 0x83, 0x91, 0x7d,
-	0xd5, 0x23, 0x2c, 0xfe, 0x5e, 0x8d, 0x62, 0x59, 0x78, 0x78, 0x06, 0xd0, 0x1f, 0x27, 0xc9, 0xe4,
-	0x91, 0x62, 0x43, 0xc5, 0xdb, 0x8b, 0x44, 0xe1, 0xbd, 0xf8, 0xee, 0x8f, 0x3e, 0xf9, 0x09, 0x00,
-	0x00, 0xff, 0xff, 0x43, 0xc0, 0x92, 0x1e, 0xe8, 0x02, 0x00, 0x00,
+	// 422 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x53, 0x4f, 0xef, 0xd2, 0x40,
+	0x10, 0xed, 0xd2, 0x1f, 0xfd, 0xfd, 0x32, 0x05, 0x5c, 0x17, 0x35, 0x04, 0x3d, 0xd4, 0x4d, 0x34,
+	0x5c, 0x6c, 0x0c, 0x1e, 0xd0, 0x83, 0x26, 0x42, 0xe2, 0x9f, 0x04, 0x89, 0x29, 0x46, 0xcf, 0x2b,
+	0x5d, 0xeb, 0xc6, 0x76, 0x97, 0x6c, 0x17, 0x15, 0xbf, 0x84, 0x1f, 0xcd, 0x93, 0xdf, 0xc7, 0x74,
+	0x0b, 0xd8, 0x36, 0xd5, 0x03, 0x97, 0x66, 0x66, 0xf2, 0xde, 0x9b, 0x37, 0xb3, 0x53, 0x18, 0x6a,
+	0x9e, 0x29, 0xc3, 0x37, 0x4a, 0x1a, 0xad, 0xd2, 0x70, 0xab, 0x95, 0x51, 0xa4, 0x5f, 0x2b, 0xd2,
+	0x9f, 0x08, 0x7a, 0x2b, 0x6e, 0xbe, 0x29, 0xfd, 0x65, 0x6d, 0x98, 0xe1, 0x64, 0x06, 0xdd, 0xbc,
+	0x08, 0x46, 0x28, 0x40, 0x93, 0xc1, 0xf4, 0x6e, 0x58, 0x17, 0xa9, 0x62, 0x43, 0xfb, 0x8d, 0x4a,
+	0x3c, 0x7d, 0x01, 0xdd, 0x52, 0xe1, 0x0a, 0x2e, 0x56, 0x4a, 0x72, 0xec, 0x14, 0xd1, 0xeb, 0x38,
+	0xe5, 0x18, 0x91, 0x1e, 0x5c, 0x2d, 0x39, 0xd3, 0x52, 0xc8, 0x04, 0x77, 0xc8, 0x00, 0xe0, 0x3d,
+	0x4b, 0x45, 0xcc, 0x8c, 0x50, 0x12, 0xbb, 0xc4, 0x87, 0xcb, 0xb7, 0x9a, 0xc7, 0x62, 0x63, 0xf0,
+	0x05, 0x0d, 0xc0, 0x7b, 0xc3, 0x8c, 0x16, 0xdf, 0xc9, 0x2d, 0xf0, 0x32, 0x1b, 0x59, 0x2f, 0xbd,
+	0xe8, 0x90, 0xd1, 0xdf, 0x08, 0xfc, 0x25, 0xdb, 0x73, 0x7d, 0xc0, 0xbd, 0x02, 0xbf, 0xb0, 0xc7,
+	0xa5, 0x79, 0xb7, 0xdf, 0x1e, 0x8d, 0xdf, 0x6f, 0x18, 0xaf, 0x10, 0xc2, 0xc5, 0x5f, 0x74, 0x54,
+	0xa5, 0x92, 0x1b, 0xd0, 0x4d, 0x0b, 0xdc, 0xa8, 0x13, 0xa0, 0xc9, 0xf5, 0xa8, 0x4c, 0xc8, 0x83,
+	0x93, 0x0f, 0x37, 0x40, 0x13, 0x7f, 0x7a, 0xb3, 0x21, 0x5d, 0xaa, 0x9e, 0xec, 0xcd, 0xc0, 0xaf,
+	0x34, 0x20, 0xd7, 0xc0, 0x7f, 0xbe, 0x31, 0xe2, 0xab, 0x1d, 0x36, 0xc7, 0x4e, 0x31, 0xed, 0x07,
+	0x2e, 0x92, 0xcf, 0x26, 0xc7, 0x88, 0x00, 0x78, 0x73, 0xc1, 0x72, 0x9e, 0xe3, 0x0e, 0xbd, 0x07,
+	0xfd, 0x85, 0x92, 0x9f, 0x44, 0xb2, 0xd3, 0x16, 0x5c, 0xd8, 0xc9, 0xc5, 0x0f, 0x9e, 0x8f, 0x50,
+	0xe0, 0x16, 0x76, 0x6c, 0x42, 0xbd, 0x72, 0xbf, 0xd3, 0x5f, 0x2e, 0xf4, 0x23, 0x6b, 0x64, 0x51,
+	0x1a, 0x21, 0x2f, 0x01, 0x27, 0xdc, 0xd4, 0x35, 0x86, 0xcd, 0x07, 0x54, 0x92, 0x8f, 0xef, 0x34,
+	0x8a, 0x35, 0x0a, 0x75, 0xc8, 0xbc, 0xe6, 0xb9, 0x5d, 0x63, 0xfc, 0xef, 0x05, 0x53, 0xe7, 0x21,
+	0x22, 0x4f, 0x8f, 0x93, 0x9d, 0x47, 0x7f, 0x76, 0xda, 0xd2, 0xb9, 0xed, 0x0f, 0xe7, 0xd8, 0xca,
+	0xbe, 0xfd, 0x9f, 0xb3, 0xb6, 0xf4, 0x27, 0x70, 0xb9, 0x2d, 0x4f, 0x92, 0xb4, 0x3f, 0xf7, 0xb8,
+	0xbd, 0x4c, 0x1d, 0xf2, 0x18, 0x20, 0xde, 0x65, 0xd9, 0x7e, 0x6d, 0x98, 0x36, 0xed, 0xed, 0xdb,
+	0x8a, 0xd4, 0xf9, 0xe8, 0xd9, 0x5f, 0xf4, 0xd1, 0x9f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x5f, 0x9d,
+	0x2f, 0xb4, 0xb9, 0x03, 0x00, 0x00,
 }
 
 // Reference imports to suppress errors if they are not otherwise used.
@@ -267,6 +348,7 @@ type RemoteControlClient interface {
 	Activations(ctx context.Context, in *None, opts ...grpc.CallOption) (RemoteControl_ActivationsClient, error)
 	Biases(ctx context.Context, in *None, opts ...grpc.CallOption) (RemoteControl_BiasesClient, error)
 	Weights(ctx context.Context, in *None, opts ...grpc.CallOption) (RemoteControl_WeightsClient, error)
+	State(ctx context.Context, in *None, opts ...grpc.CallOption) (RemoteControl_StateClient, error)
 	Predict(ctx context.Context, in *Matrix, opts ...grpc.CallOption) (*Matrix, error)
 	DummyStart(ctx context.Context, in *None, opts ...grpc.CallOption) (*None, error)
 }
@@ -384,9 +466,41 @@ func (x *remoteControlWeightsClient) Recv() (*LayerMatrix, error) {
 	return m, nil
 }
 
+func (c *remoteControlClient) State(ctx context.Context, in *None, opts ...grpc.CallOption) (RemoteControl_StateClient, error) {
+	stream, err := c.cc.NewStream(ctx, &_RemoteControl_serviceDesc.Streams[3], "/remotecontrol.RemoteControl/State", opts...)
+	if err != nil {
+		return nil, err
+	}
+	x := &remoteControlStateClient{stream}
+	if err := x.ClientStream.SendMsg(in); err != nil {
+		return nil, err
+	}
+	if err := x.ClientStream.CloseSend(); err != nil {
+		return nil, err
+	}
+	return x, nil
+}
+
+type RemoteControl_StateClient interface {
+	Recv() (*NetworkState, error)
+	grpc.ClientStream
+}
+
+type remoteControlStateClient struct {
+	grpc.ClientStream
+}
+
+func (x *remoteControlStateClient) Recv() (*NetworkState, error) {
+	m := new(NetworkState)
+	if err := x.ClientStream.RecvMsg(m); err != nil {
+		return nil, err
+	}
+	return m, nil
+}
+
 func (c *remoteControlClient) Predict(ctx context.Context, in *Matrix, opts ...grpc.CallOption) (*Matrix, error) {
 	out := new(Matrix)
-	err := c.cc.Invoke(ctx, "/remotecontrol.RemoteControl/Predict", in, out, opts...)
+	err := c.cc.Invoke(ctx, "/remotecontrol.RemoteControl/predict", in, out, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -408,6 +522,7 @@ type RemoteControlServer interface {
 	Activations(*None, RemoteControl_ActivationsServer) error
 	Biases(*None, RemoteControl_BiasesServer) error
 	Weights(*None, RemoteControl_WeightsServer) error
+	State(*None, RemoteControl_StateServer) error
 	Predict(context.Context, *Matrix) (*Matrix, error)
 	DummyStart(context.Context, *None) (*None, error)
 }
@@ -428,6 +543,9 @@ func (*UnimplementedRemoteControlServer) Biases(req *None, srv RemoteControl_Bia
 func (*UnimplementedRemoteControlServer) Weights(req *None, srv RemoteControl_WeightsServer) error {
 	return status.Errorf(codes.Unimplemented, "method Weights not implemented")
 }
+func (*UnimplementedRemoteControlServer) State(req *None, srv RemoteControl_StateServer) error {
+	return status.Errorf(codes.Unimplemented, "method State not implemented")
+}
 func (*UnimplementedRemoteControlServer) Predict(ctx context.Context, req *Matrix) (*Matrix, error) {
 	return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented")
 }
@@ -520,6 +638,27 @@ func (x *remoteControlWeightsServer) Send(m *LayerMatrix) error {
 	return x.ServerStream.SendMsg(m)
 }
 
+func _RemoteControl_State_Handler(srv interface{}, stream grpc.ServerStream) error {
+	m := new(None)
+	if err := stream.RecvMsg(m); err != nil {
+		return err
+	}
+	return srv.(RemoteControlServer).State(m, &remoteControlStateServer{stream})
+}
+
+type RemoteControl_StateServer interface {
+	Send(*NetworkState) error
+	grpc.ServerStream
+}
+
+type remoteControlStateServer struct {
+	grpc.ServerStream
+}
+
+func (x *remoteControlStateServer) Send(m *NetworkState) error {
+	return x.ServerStream.SendMsg(m)
+}
+
 func _RemoteControl_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
 	in := new(Matrix)
 	if err := dec(in); err != nil {
@@ -565,7 +704,7 @@ var _RemoteControl_serviceDesc = grpc.ServiceDesc{
 			Handler:    _RemoteControl_GetConfiguration_Handler,
 		},
 		{
-			MethodName: "Predict",
+			MethodName: "predict",
 			Handler:    _RemoteControl_Predict_Handler,
 		},
 		{
@@ -589,6 +728,11 @@ var _RemoteControl_serviceDesc = grpc.ServiceDesc{
 			Handler:       _RemoteControl_Weights_Handler,
 			ServerStreams: true,
 		},
+		{
+			StreamName:    "State",
+			Handler:       _RemoteControl_State_Handler,
+			ServerStreams: true,
+		},
 	},
 	Metadata: "remotecontrol.proto",
 }

+ 13 - 1
neuralnetwork/remotecontrol/remotecontrol.proto

@@ -27,6 +27,17 @@
 
 package remotecontrol;
 
+message NetworkState {
+    enum State {
+        None = 0;
+        Idle = 1;
+        Learning = 2;
+        Validation = 3;
+        Predict = 4;
+    }
+    State state = 1;
+}
+
 message Matrix {
     bytes matrix = 1;
 }
@@ -54,6 +65,7 @@ service RemoteControl {
     rpc Activations(None) returns (stream LayerMatrix) {}
     rpc Biases(None) returns (stream LayerMatrix) {}
     rpc Weights(None) returns (stream LayerMatrix) {}
-    rpc Predict(Matrix) returns (Matrix) {}
+    rpc State(None) returns (stream NetworkState) {}
+    rpc predict(Matrix) returns (Matrix) {}
     rpc dummyStart(None) returns (None) {}
 }