Procházet zdrojové kódy

Add RC start

- Add start button
- Make neural network cycle run on start button click
Alexey Edelev před 5 roky
rodič
revize
5db9584508

+ 12 - 1
gui/main.qml

@@ -74,7 +74,7 @@ ApplicationWindow {
 
                                 anchors.fill: parent
                                 radius: 15
-                                color: "transparent"
+                                color: "#00ff00"
                                 ColorAnimation {
                                     id: anim
                                     target: neuron
@@ -150,4 +150,15 @@ ApplicationWindow {
             }
         }
     }
+
+    Button {
+        id: start
+        text: "Start"
+        anchors.right: parent.right
+        anchors.top: parent.top
+        anchors.margins: 20
+        onClicked: {
+            visualizerModel.start();
+        }
+    }
 }

+ 6 - 0
gui/visualizermodel.cpp

@@ -114,3 +114,9 @@ LayerTrigger *VisualizerModel::weightTrigger(int layer)
     QQmlEngine::setObjectOwnership(trigger, QQmlEngine::CppOwnership);
     return trigger;
 }
+
+
+void VisualizerModel::start()
+{
+    m_client->dummyStart({});
+}

+ 1 - 1
gui/visualizermodel.h

@@ -59,7 +59,7 @@ public:
     Q_INVOKABLE ValueIndicator *weight(int layer, int row, int column);
     Q_INVOKABLE LayerTrigger *activationTrigger(int layer);
     Q_INVOKABLE LayerTrigger *weightTrigger(int layer);
-
+    Q_INVOKABLE void start();
 signals:
     void sizesChanged();
 

+ 0 - 48
neuralnetwork/main.go

@@ -1,14 +1,8 @@
 package main
 
 import (
-	"fmt"
-	"log"
-	"os"
-	"time"
-
 	neuralnetwork "./neuralnetworkbase"
 	remotecontrol "./remotecontrol"
-	teach "./teach"
 )
 
 func main() {
@@ -40,48 +34,6 @@ func main() {
 	// 	fmt.Printf("A before:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
 	// }
 
-	go func() {
-		// teacher := teach.NewMNISTReader("./minst.data", "./mnist.labels")
-		teacher := teach.NewTextDataReader("wine.data", 5)
-		nn.Teach(teacher, 500)
-
-		// for i := 0; i < nn.Count; i++ {
-		// 	if i > 0 {
-		// 		fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
-		// 		fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
-		// 		fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
-		// 	}
-		// 	fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
-		// }
-
-		outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
-		if err != nil {
-			log.Fatal(err)
-		}
-		defer outFile.Close()
-		nn.SaveState(outFile)
-		outFile.Close()
-
-		time.Sleep(5 * time.Second)
-		failCount := 0
-		teacher.Reset()
-		for true {
-			if !teacher.NextValidator() {
-				fmt.Printf("Fail count: %v\n\n", failCount)
-				failCount = 0
-				teacher.Reset()
-			}
-			dataSet, expect := teacher.GetValidator()
-			index, _ := nn.Predict(dataSet)
-			//TODO: remove this is not used for visualization
-			time.Sleep(400 * time.Millisecond)
-			if expect.At(index, 0) != 1.0 {
-				failCount++
-				// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
-			}
-		}
-	}()
-
 	// nn = &neuralnetwork.NeuralNetwork{}
 	// inFile, err := os.Open("./data")
 	// if err != nil {

+ 58 - 0
neuralnetwork/remotecontrol/remotecontrol.go

@@ -30,6 +30,9 @@ import (
 	fmt "fmt"
 	"log"
 	"net"
+	"os"
+	"sync"
+	"time"
 
 	"google.golang.org/grpc/codes"
 	status "google.golang.org/grpc/status"
@@ -37,6 +40,8 @@ import (
 	neuralnetworkbase "../neuralnetworkbase"
 	"gonum.org/v1/gonum/mat"
 	grpc "google.golang.org/grpc"
+
+	teach "../teach"
 )
 
 type RemoteControl struct {
@@ -44,6 +49,7 @@ type RemoteControl struct {
 	activationsQueue chan *LayerMatrix
 	biasesQueue      chan *LayerMatrix
 	weightsQueue     chan *LayerMatrix
+	mutex            sync.Mutex
 }
 
 func (rw *RemoteControl) Init(nn *neuralnetworkbase.NeuralNetwork) {
@@ -146,6 +152,58 @@ func (rw *RemoteControl) Predict(context.Context, *Matrix) (*Matrix, error) {
 	return nil, status.Error(codes.Unimplemented, "Not implemented")
 }
 
+func (rw *RemoteControl) DummyStart(context.Context, *None) (*None, error) {
+	go func() {
+		rw.mutex.Lock()
+		defer rw.mutex.Unlock()
+		// teacher := teach.NewMNISTReader("./minst.data", "./mnist.labels")
+		teacher := teach.NewTextDataReader("wine.data", 5)
+		rw.nn.Teach(teacher, 500)
+
+		// for i := 0; i < nn.Count; i++ {
+		// 	if i > 0 {
+		// 		fmt.Printf("Weights after:\n%v\n\n", mat.Formatted(nn.Weights[i], mat.Prefix(""), mat.Excerpt(0)))
+		// 		fmt.Printf("Biases after:\n%v\n\n", mat.Formatted(nn.Biases[i], mat.Prefix(""), mat.Excerpt(0)))
+		// 		fmt.Printf("Z after:\n%v\n\n", mat.Formatted(nn.Z[i], mat.Prefix(""), mat.Excerpt(0)))
+		// 	}
+		// 	fmt.Printf("A after:\n%v\n\n", mat.Formatted(nn.A[i], mat.Prefix(""), mat.Excerpt(0)))
+		// }
+
+		outFile, err := os.OpenFile("./data", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+		if err != nil {
+			log.Fatal(err)
+		}
+		defer outFile.Close()
+		rw.nn.SaveState(outFile)
+		outFile.Close()
+
+		time.Sleep(5 * time.Second)
+		failCount := 0
+		teacher.Reset()
+		for teacher.NextValidator() {
+			dataSet, expect := teacher.GetValidator()
+			index, _ := rw.nn.Predict(dataSet)
+			//TODO: remove this is not used for visualization
+			time.Sleep(400 * time.Millisecond)
+			if expect.At(index, 0) != 1.0 {
+				failCount++
+				// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
+			}
+			if !teacher.NextValidator() {
+				fmt.Printf("Fail count: %v\n\n", failCount)
+				failCount = 0
+				teacher.Reset()
+			}
+		}
+
+		fmt.Printf("Fail count: %v\n\n", failCount)
+		failCount = 0
+		teacher.Reset()
+	}()
+
+	return &None{}, nil
+}
+
 func (rw *RemoteControl) Run() {
 	grpcServer := grpc.NewServer()
 	RegisterRemoteControlServer(grpcServer, rw)

+ 58 - 21
neuralnetwork/remotecontrol/remotecontrol.pb.go

@@ -227,27 +227,28 @@ func init() {
 func init() { proto.RegisterFile("remotecontrol.proto", fileDescriptor_9e7470c0107e56c6) }
 
 var fileDescriptor_9e7470c0107e56c6 = []byte{
-	// 313 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x92, 0xc1, 0x4b, 0xc3, 0x30,
-	0x14, 0xc6, 0x9b, 0x4d, 0x3b, 0x78, 0x71, 0x5a, 0x33, 0x95, 0x51, 0x3c, 0x94, 0x80, 0xd2, 0x8b,
-	0x45, 0xea, 0x41, 0x3c, 0x28, 0xb8, 0x1e, 0xf4, 0xa0, 0x22, 0x45, 0xf0, 0x5c, 0x6b, 0xac, 0x81,
-	0xad, 0x19, 0x49, 0x14, 0xe7, 0xff, 0xa9, 0x7f, 0x8f, 0x34, 0x9d, 0xb5, 0x2d, 0xf5, 0xb2, 0x5b,
-	0xde, 0xe3, 0xfb, 0xbe, 0xf7, 0xcb, 0x4b, 0x60, 0x24, 0xd9, 0x4c, 0x68, 0x96, 0x8a, 0x5c, 0x4b,
-	0x31, 0x0d, 0xe6, 0x52, 0x68, 0x41, 0x86, 0x8d, 0x26, 0xf5, 0xc0, 0xbe, 0x4d, 0xb4, 0xe4, 0x1f,
-	0x64, 0x0f, 0xec, 0x99, 0x39, 0x8d, 0x91, 0x87, 0xfc, 0x8d, 0x78, 0x59, 0xd1, 0x2f, 0x04, 0xf8,
-	0x26, 0x59, 0x30, 0xb9, 0xd4, 0x5d, 0x03, 0x2e, 0xcc, 0x2c, 0xd7, 0x0f, 0x8b, 0x39, 0x33, 0xe2,
-	0xcd, 0xf0, 0x30, 0x68, 0xce, 0xaa, 0x19, 0x82, 0xe8, 0x4f, 0x1d, 0xd7, 0xad, 0x64, 0x07, 0xd6,
-	0xa7, 0x85, 0x6e, 0xdc, 0xf3, 0x90, 0xbf, 0x1d, 0x97, 0x05, 0x39, 0xaa, 0x38, 0xfa, 0x1e, 0xf2,
-	0x71, 0xb8, 0xdb, 0x8a, 0x2e, 0x53, 0x2b, 0xbc, 0x53, 0xc0, 0xb5, 0x01, 0x64, 0x0b, 0xf0, 0x65,
-	0xaa, 0xf9, 0x7b, 0xa2, 0xb9, 0xc8, 0x95, 0x63, 0x11, 0x0c, 0x83, 0x47, 0xc6, 0xb3, 0x57, 0xad,
-	0x1c, 0x44, 0x00, 0xec, 0x09, 0x4f, 0x14, 0x53, 0x4e, 0x8f, 0x1e, 0xc0, 0x30, 0x12, 0xf9, 0x0b,
-	0xcf, 0xde, 0xa4, 0x11, 0x17, 0x38, 0x8a, 0x7f, 0x32, 0x35, 0x46, 0x5e, 0xbf, 0xc0, 0x31, 0x05,
-	0xb5, 0x61, 0xed, 0x4e, 0xe4, 0x2c, 0xfc, 0xee, 0xc1, 0x30, 0x36, 0x20, 0x51, 0x09, 0x42, 0xae,
-	0xc0, 0xc9, 0x98, 0x6e, 0x66, 0x8c, 0x5a, 0xb0, 0x85, 0xd5, 0xdd, 0x6f, 0x35, 0x1b, 0x16, 0x6a,
-	0x91, 0x49, 0x83, 0xb9, 0x3b, 0xc3, 0xfd, 0x7f, 0xc1, 0xd4, 0x3a, 0x46, 0xe4, 0xfc, 0xf7, 0x66,
-	0xab, 0xd9, 0x2f, 0xaa, 0x2d, 0xad, 0xe6, 0x3f, 0x83, 0xc1, 0xbd, 0x64, 0xcf, 0x3c, 0xd5, 0xa4,
-	0xfb, 0xbd, 0xdc, 0xee, 0x36, 0xb5, 0x9e, 0x6c, 0xf3, 0x2f, 0x4f, 0x7e, 0x02, 0x00, 0x00, 0xff,
-	0xff, 0x89, 0xa6, 0x82, 0x97, 0xae, 0x02, 0x00, 0x00,
+	// 331 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x92, 0x41, 0x4f, 0xc2, 0x40,
+	0x10, 0x85, 0xbb, 0xa0, 0x25, 0x99, 0x15, 0xad, 0x83, 0x1a, 0x42, 0x3c, 0x34, 0x9b, 0x68, 0xb8,
+	0x48, 0x0c, 0x1e, 0xd4, 0x83, 0x26, 0xc2, 0x41, 0x0f, 0x6a, 0x4c, 0x35, 0xf1, 0x5c, 0x61, 0xc5,
+	0x4d, 0x68, 0x97, 0xec, 0x0e, 0x46, 0xfc, 0x15, 0xfe, 0x39, 0xff, 0x8f, 0xe9, 0x82, 0x48, 0x49,
+	0xbd, 0x70, 0xeb, 0xbc, 0xbc, 0xf7, 0xe6, 0xeb, 0xb4, 0x50, 0x33, 0x32, 0xd1, 0x24, 0x7b, 0x3a,
+	0x25, 0xa3, 0x87, 0xad, 0x91, 0xd1, 0xa4, 0xb1, 0x9a, 0x13, 0x45, 0x08, 0xfe, 0x5d, 0x4c, 0x46,
+	0x7d, 0xe0, 0x1e, 0xf8, 0x89, 0x7b, 0xaa, 0xb3, 0x90, 0x35, 0x37, 0xa2, 0xd9, 0x24, 0xbe, 0x19,
+	0xf0, 0xdb, 0x78, 0x22, 0xcd, 0xcc, 0x77, 0x03, 0x3c, 0x0b, 0xcb, 0x94, 0x9e, 0x26, 0x23, 0xe9,
+	0xcc, 0x9b, 0xed, 0xc3, 0x56, 0x7e, 0xd7, 0x42, 0xa0, 0xd5, 0xfd, 0x73, 0x47, 0x8b, 0x51, 0xdc,
+	0x81, 0xf5, 0x61, 0xe6, 0xab, 0x97, 0x42, 0xd6, 0xdc, 0x8e, 0xa6, 0x03, 0x1e, 0xcd, 0x39, 0xca,
+	0x21, 0x6b, 0xf2, 0xf6, 0xee, 0x52, 0xf5, 0xb4, 0x75, 0x8e, 0x77, 0x0a, 0x7c, 0x61, 0x01, 0x6e,
+	0x01, 0xbf, 0xea, 0x91, 0x7a, 0x8f, 0x49, 0xe9, 0xd4, 0x06, 0x1e, 0x72, 0xa8, 0x3c, 0x4b, 0x35,
+	0x78, 0x23, 0x1b, 0x30, 0x04, 0xf0, 0x3b, 0x2a, 0xb6, 0xd2, 0x06, 0x25, 0x71, 0x00, 0xd5, 0xae,
+	0x4e, 0x5f, 0xd5, 0x60, 0x6c, 0x9c, 0x39, 0xc3, 0xb1, 0xea, 0x53, 0xda, 0x3a, 0x0b, 0xcb, 0x19,
+	0x8e, 0x1b, 0x84, 0x0f, 0x6b, 0xf7, 0x3a, 0x95, 0xed, 0xaf, 0x32, 0x54, 0x23, 0x07, 0xd2, 0x9d,
+	0x82, 0xe0, 0x35, 0x04, 0x03, 0x49, 0xf9, 0x8e, 0xda, 0x12, 0x6c, 0x16, 0x6d, 0xec, 0x2f, 0x89,
+	0xb9, 0x88, 0xf0, 0xb0, 0x93, 0x63, 0x2e, 0xee, 0x68, 0xfc, 0x7f, 0x60, 0xe1, 0x1d, 0x33, 0xbc,
+	0xf8, 0x7d, 0xb3, 0xd5, 0xe2, 0x97, 0xf3, 0x2b, 0xad, 0x96, 0x3f, 0x87, 0xca, 0x83, 0x91, 0x7d,
+	0xd5, 0x23, 0x2c, 0xfe, 0x5e, 0x8d, 0x62, 0x59, 0x78, 0x78, 0x06, 0xd0, 0x1f, 0x27, 0xc9, 0xe4,
+	0x91, 0x62, 0x43, 0xc5, 0xdb, 0x8b, 0x44, 0xe1, 0xbd, 0xf8, 0xee, 0x8f, 0x3e, 0xf9, 0x09, 0x00,
+	0x00, 0xff, 0xff, 0x43, 0xc0, 0x92, 0x1e, 0xe8, 0x02, 0x00, 0x00,
 }
 
 // Reference imports to suppress errors if they are not otherwise used.
@@ -267,6 +268,7 @@ type RemoteControlClient interface {
 	Biases(ctx context.Context, in *None, opts ...grpc.CallOption) (RemoteControl_BiasesClient, error)
 	Weights(ctx context.Context, in *None, opts ...grpc.CallOption) (RemoteControl_WeightsClient, error)
 	Predict(ctx context.Context, in *Matrix, opts ...grpc.CallOption) (*Matrix, error)
+	DummyStart(ctx context.Context, in *None, opts ...grpc.CallOption) (*None, error)
 }
 
 type remoteControlClient struct {
@@ -391,6 +393,15 @@ func (c *remoteControlClient) Predict(ctx context.Context, in *Matrix, opts ...g
 	return out, nil
 }
 
+func (c *remoteControlClient) DummyStart(ctx context.Context, in *None, opts ...grpc.CallOption) (*None, error) {
+	out := new(None)
+	err := c.cc.Invoke(ctx, "/remotecontrol.RemoteControl/dummyStart", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
 // RemoteControlServer is the server API for RemoteControl service.
 type RemoteControlServer interface {
 	GetConfiguration(context.Context, *None) (*Configuration, error)
@@ -398,6 +409,7 @@ type RemoteControlServer interface {
 	Biases(*None, RemoteControl_BiasesServer) error
 	Weights(*None, RemoteControl_WeightsServer) error
 	Predict(context.Context, *Matrix) (*Matrix, error)
+	DummyStart(context.Context, *None) (*None, error)
 }
 
 // UnimplementedRemoteControlServer can be embedded to have forward compatible implementations.
@@ -419,6 +431,9 @@ func (*UnimplementedRemoteControlServer) Weights(req *None, srv RemoteControl_We
 func (*UnimplementedRemoteControlServer) Predict(ctx context.Context, req *Matrix) (*Matrix, error) {
 	return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented")
 }
+func (*UnimplementedRemoteControlServer) DummyStart(ctx context.Context, req *None) (*None, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method DummyStart not implemented")
+}
 
 func RegisterRemoteControlServer(s *grpc.Server, srv RemoteControlServer) {
 	s.RegisterService(&_RemoteControl_serviceDesc, srv)
@@ -523,6 +538,24 @@ func _RemoteControl_Predict_Handler(srv interface{}, ctx context.Context, dec fu
 	return interceptor(ctx, in, info, handler)
 }
 
+func _RemoteControl_DummyStart_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(None)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(RemoteControlServer).DummyStart(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/remotecontrol.RemoteControl/DummyStart",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(RemoteControlServer).DummyStart(ctx, req.(*None))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
 var _RemoteControl_serviceDesc = grpc.ServiceDesc{
 	ServiceName: "remotecontrol.RemoteControl",
 	HandlerType: (*RemoteControlServer)(nil),
@@ -535,6 +568,10 @@ var _RemoteControl_serviceDesc = grpc.ServiceDesc{
 			MethodName: "Predict",
 			Handler:    _RemoteControl_Predict_Handler,
 		},
+		{
+			MethodName: "dummyStart",
+			Handler:    _RemoteControl_DummyStart_Handler,
+		},
 	},
 	Streams: []grpc.StreamDesc{
 		{

+ 1 - 0
neuralnetwork/remotecontrol/remotecontrol.proto

@@ -55,4 +55,5 @@ service RemoteControl {
     rpc Biases(None) returns (stream LayerMatrix) {}
     rpc Weights(None) returns (stream LayerMatrix) {}
     rpc Predict(Matrix) returns (Matrix) {}
+    rpc dummyStart(None) returns (None) {}
 }