Forráskód Böngészése

Add retrain button and functionality for handwriting

Alexey Edelev 5 éve
szülő
commit
3363beb499

+ 1 - 0
handwriting/handwriting.proto

@@ -44,4 +44,5 @@ message None {}
 service Handwriting {
     rpc recognize(Matrix) returns (Result) {}
     rpc setNeuralNetworkData(NeuralNetworkRaw) returns (None) {}
+    rpc reTrain(None) returns (None) {}
 }

+ 10 - 1
handwriting/handwriting/handwriting.go

@@ -7,6 +7,7 @@ import (
 	"net"
 
 	neuralnetwork "../../neuralnetwork/neuralnetwork"
+	training "../../neuralnetwork/training"
 	"gonum.org/v1/gonum/mat"
 	grpc "google.golang.org/grpc"
 )
@@ -17,7 +18,7 @@ type HandwritingService struct {
 
 func NewHandwritingService() (hws *HandwritingService) {
 	hws = &HandwritingService{}
-	hws.nn, _ = neuralnetwork.NewNeuralNetwork([]int{784, 24, 24, 10}, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
+	hws.nn, _ = neuralnetwork.NewNeuralNetwork([]int{784, 16, 16, 10}, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
 		NuPlus:   1.2,
 		NuMinus:  0.5,
 		DeltaMax: 50.0,
@@ -41,6 +42,14 @@ func (hws *HandwritingService) SetNeuralNetworkData(ctx context.Context, nnRaw *
 	return &None{}, nil
 }
 
+func (hws *HandwritingService) ReTrain(context.Context, *None) (*None, error) {
+	fmt.Println("ReTrain")
+	trainer := training.NewMNISTReader("./mnist.data", "./mnist.labels")
+	hws.nn.Train(trainer, 100)
+	fmt.Println("ReTrain finished")
+	return &None{}, nil
+}
+
 func (hws *HandwritingService) Run() {
 	grpcServer := grpc.NewServer()
 	RegisterHandwritingServer(grpcServer, hws)

+ 52 - 15
handwriting/handwriting/handwriting.pb.go

@@ -182,21 +182,22 @@ func init() {
 func init() { proto.RegisterFile("handwriting.proto", fileDescriptor_d3287f4c1e120e43) }
 
 var fileDescriptor_d3287f4c1e120e43 = []byte{
-	// 213 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0xcc, 0x48, 0xcc, 0x4b,
-	0x29, 0x2f, 0xca, 0x2c, 0xc9, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x46,
-	0x12, 0x52, 0x92, 0xe1, 0x62, 0xf3, 0x4d, 0x2c, 0x29, 0xca, 0xac, 0x10, 0x12, 0xe2, 0x62, 0x49,
-	0x49, 0x2c, 0x49, 0x94, 0x60, 0x54, 0x60, 0xd6, 0x60, 0x0c, 0x02, 0xb3, 0x95, 0x8c, 0xb8, 0xd8,
-	0x82, 0x52, 0x8b, 0x4b, 0x73, 0x4a, 0x84, 0x34, 0xb8, 0xf8, 0x8b, 0xc0, 0x2c, 0xe7, 0x8c, 0xc4,
-	0xa2, 0xc4, 0xe4, 0x92, 0xd4, 0x22, 0x09, 0x46, 0x05, 0x46, 0x0d, 0xde, 0x20, 0x74, 0x61, 0x25,
-	0x35, 0x2e, 0x01, 0xbf, 0xd4, 0xd2, 0xa2, 0xc4, 0x1c, 0xbf, 0xd4, 0x92, 0xf2, 0xfc, 0xa2, 0xec,
-	0xa0, 0xc4, 0x72, 0x24, 0xb3, 0x19, 0x35, 0x78, 0xa0, 0x66, 0xb3, 0x71, 0xb1, 0xf8, 0xe5, 0xe7,
-	0xa5, 0x1a, 0x4d, 0x62, 0xe4, 0xe2, 0xf6, 0x40, 0xb8, 0x48, 0xc8, 0x9c, 0x8b, 0xb3, 0x28, 0x35,
-	0x39, 0x3f, 0x3d, 0x2f, 0xb3, 0x2a, 0x55, 0x48, 0x58, 0x0f, 0xd9, 0xfd, 0x10, 0x97, 0x4a, 0xa1,
-	0x0a, 0x42, 0x1c, 0xa8, 0xc4, 0x20, 0xe4, 0xc5, 0x25, 0x52, 0x9c, 0x5a, 0x82, 0x62, 0xb7, 0x4b,
-	0x62, 0x49, 0xa2, 0x90, 0x2c, 0x8a, 0x72, 0x74, 0xb7, 0x49, 0x09, 0xa2, 0x4a, 0xe7, 0xe7, 0xa5,
-	0x2a, 0x31, 0x24, 0xb1, 0x81, 0x83, 0xca, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x34, 0x3f,
-	0x3f, 0x3f, 0x01, 0x00, 0x00,
+	// 228 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x90, 0xc1, 0x4a, 0x03, 0x41,
+	0x0c, 0x86, 0x3b, 0x58, 0x46, 0x4c, 0x15, 0x6d, 0xf4, 0x50, 0x8a, 0x42, 0x99, 0x83, 0xcc, 0xa9,
+	0x60, 0x3d, 0xf8, 0x00, 0x7a, 0x10, 0xc1, 0x3d, 0x0c, 0xbe, 0x40, 0x6c, 0x43, 0x3b, 0x58, 0x66,
+	0x24, 0x4d, 0x59, 0xf1, 0xfd, 0x7c, 0x2f, 0x61, 0x57, 0x70, 0x67, 0xf5, 0xf6, 0xf3, 0xff, 0xc9,
+	0x9f, 0x8f, 0xc0, 0x78, 0x43, 0x69, 0x55, 0x4b, 0xd4, 0x98, 0xd6, 0xf3, 0x77, 0xc9, 0x9a, 0x71,
+	0xd4, 0xb1, 0xdc, 0x25, 0xd8, 0x67, 0x52, 0x89, 0x1f, 0x88, 0x30, 0x5c, 0x91, 0xd2, 0xc4, 0xcc,
+	0x0e, 0xbc, 0x09, 0x8d, 0x76, 0x0b, 0xb0, 0x81, 0x77, 0xfb, 0xad, 0xa2, 0x87, 0x53, 0x69, 0xd4,
+	0xfd, 0x86, 0x84, 0x96, 0xca, 0x32, 0x31, 0x33, 0xe3, 0x4f, 0x42, 0xdf, 0x76, 0xd7, 0x70, 0x56,
+	0xf1, 0x5e, 0x68, 0x5b, 0xb1, 0xd6, 0x59, 0xde, 0x02, 0xd5, 0x9d, 0x6e, 0xe3, 0x8f, 0x7f, 0xba,
+	0x2d, 0x0c, 0xab, 0x9c, 0x78, 0xf1, 0x65, 0x60, 0xf4, 0xf8, 0x4b, 0x84, 0x77, 0x70, 0x24, 0xbc,
+	0xcc, 0xeb, 0x14, 0x3f, 0x19, 0xcf, 0xe7, 0x5d, 0xfe, 0x96, 0x74, 0x5a, 0x9a, 0x2d, 0xa0, 0x1b,
+	0xe0, 0x13, 0x5c, 0xec, 0x58, 0x8b, 0xdb, 0x0f, 0xa4, 0x84, 0x57, 0xc5, 0x78, 0x9f, 0x6d, 0x3a,
+	0x2e, 0xe3, 0x9c, 0xd8, 0x0d, 0xf0, 0x06, 0x0e, 0x85, 0x5f, 0x84, 0x62, 0xc2, 0xbf, 0xf9, 0xbf,
+	0x2b, 0xaf, 0xb6, 0xf9, 0xee, 0xed, 0x77, 0x00, 0x00, 0x00, 0xff, 0xff, 0x46, 0x56, 0x5b, 0x3e,
+	0x72, 0x01, 0x00, 0x00,
 }
 
 // Reference imports to suppress errors if they are not otherwise used.
@@ -213,6 +214,7 @@ const _ = grpc.SupportPackageIsVersion4
 type HandwritingClient interface {
 	Recognize(ctx context.Context, in *Matrix, opts ...grpc.CallOption) (*Result, error)
 	SetNeuralNetworkData(ctx context.Context, in *NeuralNetworkRaw, opts ...grpc.CallOption) (*None, error)
+	ReTrain(ctx context.Context, in *None, opts ...grpc.CallOption) (*None, error)
 }
 
 type handwritingClient struct {
@@ -241,10 +243,20 @@ func (c *handwritingClient) SetNeuralNetworkData(ctx context.Context, in *Neural
 	return out, nil
 }
 
+func (c *handwritingClient) ReTrain(ctx context.Context, in *None, opts ...grpc.CallOption) (*None, error) {
+	out := new(None)
+	err := c.cc.Invoke(ctx, "/handwriting.Handwriting/reTrain", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
 // HandwritingServer is the server API for Handwriting service.
 type HandwritingServer interface {
 	Recognize(context.Context, *Matrix) (*Result, error)
 	SetNeuralNetworkData(context.Context, *NeuralNetworkRaw) (*None, error)
+	ReTrain(context.Context, *None) (*None, error)
 }
 
 // UnimplementedHandwritingServer can be embedded to have forward compatible implementations.
@@ -257,6 +269,9 @@ func (*UnimplementedHandwritingServer) Recognize(ctx context.Context, req *Matri
 func (*UnimplementedHandwritingServer) SetNeuralNetworkData(ctx context.Context, req *NeuralNetworkRaw) (*None, error) {
 	return nil, status.Errorf(codes.Unimplemented, "method SetNeuralNetworkData not implemented")
 }
+func (*UnimplementedHandwritingServer) ReTrain(ctx context.Context, req *None) (*None, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method ReTrain not implemented")
+}
 
 func RegisterHandwritingServer(s *grpc.Server, srv HandwritingServer) {
 	s.RegisterService(&_Handwriting_serviceDesc, srv)
@@ -298,6 +313,24 @@ func _Handwriting_SetNeuralNetworkData_Handler(srv interface{}, ctx context.Cont
 	return interceptor(ctx, in, info, handler)
 }
 
+func _Handwriting_ReTrain_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(None)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(HandwritingServer).ReTrain(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/handwriting.Handwriting/ReTrain",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(HandwritingServer).ReTrain(ctx, req.(*None))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
 var _Handwriting_serviceDesc = grpc.ServiceDesc{
 	ServiceName: "handwriting.Handwriting",
 	HandlerType: (*HandwritingServer)(nil),
@@ -310,6 +343,10 @@ var _Handwriting_serviceDesc = grpc.ServiceDesc{
 			MethodName: "setNeuralNetworkData",
 			Handler:    _Handwriting_SetNeuralNetworkData_Handler,
 		},
+		{
+			MethodName: "reTrain",
+			Handler:    _Handwriting_ReTrain_Handler,
+		},
 	},
 	Streams:  []grpc.StreamDesc{},
 	Metadata: "handwriting.proto",

+ 3 - 0
handwriting/handwritingui/handwritingengine.cpp

@@ -59,6 +59,9 @@ HandwritingEngine::HandwritingEngine(QObject *parent) : QObject(parent)
     m_matrix.setData(emptyData);
 }
 
+void HandwritingEngine::retrain() {
+    m_client->reTrain({});
+}
 
 void HandwritingEngine::recognize()
 {

+ 1 - 0
handwriting/handwritingui/handwritingengine.h

@@ -22,6 +22,7 @@ public:
     explicit HandwritingEngine(QObject *parent = nullptr);
     virtual ~HandwritingEngine() = default;
 
+    Q_INVOKABLE void retrain();
     Q_INVOKABLE void recognize();
     Q_INVOKABLE void setNeuralNetworkData(const QString &networkDataPath);
     Q_INVOKABLE void updateValue(int index, double value);

+ 7 - 0
handwriting/handwritingui/main.qml

@@ -120,6 +120,13 @@ ApplicationWindow {
                 }
             }
         }
+        Button {
+            id: tainButton
+            text: "Re-run trainig"
+            onClicked: {
+                hwengine.retrain()
+            }
+        }
         Button {
             id: clearButton
             text: "Clear"

+ 3 - 2
neuralnetwork/neuralnetwork/neuralnetwork.go

@@ -240,8 +240,9 @@ func (nn *NeuralNetwork) TrainOnline(trainer training.Trainer, epocs int) {
 }
 
 func (nn *NeuralNetwork) TrainBatch(trainer training.Trainer, epocs int) {
+	fmt.Printf("Start training in %v threads\n", 2*runtime.NumCPU())
 	for t := 0; t < epocs; t++ {
-		batchWorkers := nn.runBatchWorkers(runtime.NumCPU(), trainer)
+		batchWorkers := nn.runBatchWorkers(2*runtime.NumCPU(), trainer)
 
 		for l := 1; l < nn.LayerCount; l++ {
 			bGradient, ok := nn.BGradient[l].(BatchGradientDescent)
@@ -328,7 +329,7 @@ func (nn *NeuralNetwork) LoadState(reader io.Reader) {
 
 	for i := 0; i < nn.LayerCount; i++ {
 		nn.Sizes[i] = int(binary.LittleEndian.Uint32(sizeBuffer[i*4:]))
-		// fmt.Printf("LoadState: nn.Sizes[%d] %d \n", i, nn.Sizes[i])
+		fmt.Printf("LoadState: nn.Sizes[%d] %d \n", i, nn.Sizes[i])
 	}
 
 	nn.Weights = []*mat.Dense{&mat.Dense{}}