Forráskód Böngészése

Implement handwriting service

Alexey Edelev 5 éve
szülő
commit
04e6cf61fd

+ 9 - 0
build.sh

@@ -27,4 +27,13 @@ rm -f $SNAKE_RPC_PATH/*.pb.go
 protoc -I$SNAKE_RPC_PATH --go_out=plugins=grpc:$SNAKE_RPC_PATH $SNAKE_RPC_PATH/snakesimulator.proto
 pushd snakesimulator
 go build -o $GOBIN/snakesimulator
+popd
+
+
+export HANDWRITING_RPC_PATH=$PWD/handwriting
+mkdir -p $HANDWRITING_RPC_PATH
+rm -f $HANDWRITING_RPC_PATH/handwriting/*.pb.go
+protoc -I$HANDWRITING_RPC_PATH --go_out=plugins=grpc:$HANDWRITING_RPC_PATH/handwriting $HANDWRITING_RPC_PATH/handwriting.proto
+pushd handwriting
+go build -o $GOBIN/handwriting
 popd

+ 56 - 0
handwriting/handwriting/handwriting.go

@@ -0,0 +1,56 @@
+package handwriting
+
+import (
+	"bytes"
+	context "context"
+	fmt "fmt"
+	"net"
+
+	neuralnetwork "../../neuralnetwork/neuralnetwork"
+	"gonum.org/v1/gonum/mat"
+	grpc "google.golang.org/grpc"
+)
+
+type HandwritingService struct {
+	nn *neuralnetwork.NeuralNetwork
+}
+
+func NewHandwritingService() (hws *HandwritingService) {
+	hws = &HandwritingService{}
+	hws.nn, _ = neuralnetwork.NewNeuralNetwork([]int{784, 24, 24, 10}, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
+		NuPlus:   1.2,
+		NuMinus:  0.5,
+		DeltaMax: 50.0,
+		DeltaMin: 0.000001,
+	}))
+	return
+}
+
+func (hws *HandwritingService) Recognize(ctx context.Context, matrix *Matrix) (*Result, error) {
+	fmt.Printf("Recognize %v size: %v\n", len(matrix.Data), hws.nn.Sizes[0])
+	dense := mat.NewDense(hws.nn.Sizes[0], 1, matrix.Data)
+	index, _ := hws.nn.Predict(dense)
+	fmt.Printf("Recognition result %v\n", index)
+	return &Result{ResultCharacter: uint32(index)}, nil
+}
+
+func (hws *HandwritingService) SetNeuralNetworkData(ctx context.Context, nnRaw *NeuralNetworkRaw) (*None, error) {
+	fmt.Println("SetNeuralNetworkData")
+	r := bytes.NewReader(nnRaw.Data)
+	hws.nn.LoadState(r)
+	return &None{}, nil
+}
+
+func (hws *HandwritingService) Run() {
+	grpcServer := grpc.NewServer()
+	RegisterHandwritingServer(grpcServer, hws)
+	lis, err := net.Listen("tcp", "localhost:65001")
+	if err != nil {
+		fmt.Printf("Failed to listen: %v\n", err)
+	}
+
+	fmt.Printf("Listen localhost:65001\n")
+	if err := grpcServer.Serve(lis); err != nil {
+		fmt.Printf("Failed to serve: %v\n", err)
+	}
+}

+ 316 - 0
handwriting/handwriting/handwriting.pb.go

@@ -0,0 +1,316 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: handwriting.proto
+
+package handwriting
+
+import (
+	context "context"
+	fmt "fmt"
+	proto "github.com/golang/protobuf/proto"
+	grpc "google.golang.org/grpc"
+	codes "google.golang.org/grpc/codes"
+	status "google.golang.org/grpc/status"
+	math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type Matrix struct {
+	Data                 []float64 `protobuf:"fixed64,1,rep,packed,name=data,proto3" json:"data,omitempty"`
+	XXX_NoUnkeyedLiteral struct{}  `json:"-"`
+	XXX_unrecognized     []byte    `json:"-"`
+	XXX_sizecache        int32     `json:"-"`
+}
+
+func (m *Matrix) Reset()         { *m = Matrix{} }
+func (m *Matrix) String() string { return proto.CompactTextString(m) }
+func (*Matrix) ProtoMessage()    {}
+func (*Matrix) Descriptor() ([]byte, []int) {
+	return fileDescriptor_d3287f4c1e120e43, []int{0}
+}
+
+func (m *Matrix) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_Matrix.Unmarshal(m, b)
+}
+func (m *Matrix) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_Matrix.Marshal(b, m, deterministic)
+}
+func (m *Matrix) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_Matrix.Merge(m, src)
+}
+func (m *Matrix) XXX_Size() int {
+	return xxx_messageInfo_Matrix.Size(m)
+}
+func (m *Matrix) XXX_DiscardUnknown() {
+	xxx_messageInfo_Matrix.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Matrix proto.InternalMessageInfo
+
+func (m *Matrix) GetData() []float64 {
+	if m != nil {
+		return m.Data
+	}
+	return nil
+}
+
+type Result struct {
+	ResultCharacter      uint32   `protobuf:"varint,1,opt,name=resultCharacter,proto3" json:"resultCharacter,omitempty"`
+	XXX_NoUnkeyedLiteral struct{} `json:"-"`
+	XXX_unrecognized     []byte   `json:"-"`
+	XXX_sizecache        int32    `json:"-"`
+}
+
+func (m *Result) Reset()         { *m = Result{} }
+func (m *Result) String() string { return proto.CompactTextString(m) }
+func (*Result) ProtoMessage()    {}
+func (*Result) Descriptor() ([]byte, []int) {
+	return fileDescriptor_d3287f4c1e120e43, []int{1}
+}
+
+func (m *Result) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_Result.Unmarshal(m, b)
+}
+func (m *Result) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_Result.Marshal(b, m, deterministic)
+}
+func (m *Result) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_Result.Merge(m, src)
+}
+func (m *Result) XXX_Size() int {
+	return xxx_messageInfo_Result.Size(m)
+}
+func (m *Result) XXX_DiscardUnknown() {
+	xxx_messageInfo_Result.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Result proto.InternalMessageInfo
+
+func (m *Result) GetResultCharacter() uint32 {
+	if m != nil {
+		return m.ResultCharacter
+	}
+	return 0
+}
+
+type NeuralNetworkRaw struct {
+	Data                 []byte   `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
+	XXX_NoUnkeyedLiteral struct{} `json:"-"`
+	XXX_unrecognized     []byte   `json:"-"`
+	XXX_sizecache        int32    `json:"-"`
+}
+
+func (m *NeuralNetworkRaw) Reset()         { *m = NeuralNetworkRaw{} }
+func (m *NeuralNetworkRaw) String() string { return proto.CompactTextString(m) }
+func (*NeuralNetworkRaw) ProtoMessage()    {}
+func (*NeuralNetworkRaw) Descriptor() ([]byte, []int) {
+	return fileDescriptor_d3287f4c1e120e43, []int{2}
+}
+
+func (m *NeuralNetworkRaw) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_NeuralNetworkRaw.Unmarshal(m, b)
+}
+func (m *NeuralNetworkRaw) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_NeuralNetworkRaw.Marshal(b, m, deterministic)
+}
+func (m *NeuralNetworkRaw) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_NeuralNetworkRaw.Merge(m, src)
+}
+func (m *NeuralNetworkRaw) XXX_Size() int {
+	return xxx_messageInfo_NeuralNetworkRaw.Size(m)
+}
+func (m *NeuralNetworkRaw) XXX_DiscardUnknown() {
+	xxx_messageInfo_NeuralNetworkRaw.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_NeuralNetworkRaw proto.InternalMessageInfo
+
+func (m *NeuralNetworkRaw) GetData() []byte {
+	if m != nil {
+		return m.Data
+	}
+	return nil
+}
+
+type None struct {
+	XXX_NoUnkeyedLiteral struct{} `json:"-"`
+	XXX_unrecognized     []byte   `json:"-"`
+	XXX_sizecache        int32    `json:"-"`
+}
+
+func (m *None) Reset()         { *m = None{} }
+func (m *None) String() string { return proto.CompactTextString(m) }
+func (*None) ProtoMessage()    {}
+func (*None) Descriptor() ([]byte, []int) {
+	return fileDescriptor_d3287f4c1e120e43, []int{3}
+}
+
+func (m *None) XXX_Unmarshal(b []byte) error {
+	return xxx_messageInfo_None.Unmarshal(m, b)
+}
+func (m *None) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+	return xxx_messageInfo_None.Marshal(b, m, deterministic)
+}
+func (m *None) XXX_Merge(src proto.Message) {
+	xxx_messageInfo_None.Merge(m, src)
+}
+func (m *None) XXX_Size() int {
+	return xxx_messageInfo_None.Size(m)
+}
+func (m *None) XXX_DiscardUnknown() {
+	xxx_messageInfo_None.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_None proto.InternalMessageInfo
+
+func init() {
+	proto.RegisterType((*Matrix)(nil), "handwriting.Matrix")
+	proto.RegisterType((*Result)(nil), "handwriting.Result")
+	proto.RegisterType((*NeuralNetworkRaw)(nil), "handwriting.NeuralNetworkRaw")
+	proto.RegisterType((*None)(nil), "handwriting.None")
+}
+
+func init() { proto.RegisterFile("handwriting.proto", fileDescriptor_d3287f4c1e120e43) }
+
+var fileDescriptor_d3287f4c1e120e43 = []byte{
+	// 213 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0xcc, 0x48, 0xcc, 0x4b,
+	0x29, 0x2f, 0xca, 0x2c, 0xc9, 0xcc, 0x4b, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x46,
+	0x12, 0x52, 0x92, 0xe1, 0x62, 0xf3, 0x4d, 0x2c, 0x29, 0xca, 0xac, 0x10, 0x12, 0xe2, 0x62, 0x49,
+	0x49, 0x2c, 0x49, 0x94, 0x60, 0x54, 0x60, 0xd6, 0x60, 0x0c, 0x02, 0xb3, 0x95, 0x8c, 0xb8, 0xd8,
+	0x82, 0x52, 0x8b, 0x4b, 0x73, 0x4a, 0x84, 0x34, 0xb8, 0xf8, 0x8b, 0xc0, 0x2c, 0xe7, 0x8c, 0xc4,
+	0xa2, 0xc4, 0xe4, 0x92, 0xd4, 0x22, 0x09, 0x46, 0x05, 0x46, 0x0d, 0xde, 0x20, 0x74, 0x61, 0x25,
+	0x35, 0x2e, 0x01, 0xbf, 0xd4, 0xd2, 0xa2, 0xc4, 0x1c, 0xbf, 0xd4, 0x92, 0xf2, 0xfc, 0xa2, 0xec,
+	0xa0, 0xc4, 0x72, 0x24, 0xb3, 0x19, 0x35, 0x78, 0xa0, 0x66, 0xb3, 0x71, 0xb1, 0xf8, 0xe5, 0xe7,
+	0xa5, 0x1a, 0x4d, 0x62, 0xe4, 0xe2, 0xf6, 0x40, 0xb8, 0x48, 0xc8, 0x9c, 0x8b, 0xb3, 0x28, 0x35,
+	0x39, 0x3f, 0x3d, 0x2f, 0xb3, 0x2a, 0x55, 0x48, 0x58, 0x0f, 0xd9, 0xfd, 0x10, 0x97, 0x4a, 0xa1,
+	0x0a, 0x42, 0x1c, 0xa8, 0xc4, 0x20, 0xe4, 0xc5, 0x25, 0x52, 0x9c, 0x5a, 0x82, 0x62, 0xb7, 0x4b,
+	0x62, 0x49, 0xa2, 0x90, 0x2c, 0x8a, 0x72, 0x74, 0xb7, 0x49, 0x09, 0xa2, 0x4a, 0xe7, 0xe7, 0xa5,
+	0x2a, 0x31, 0x24, 0xb1, 0x81, 0x83, 0xca, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x34, 0x3f,
+	0x3f, 0x3f, 0x01, 0x00, 0x00,
+}
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+const _ = grpc.SupportPackageIsVersion4
+
+// HandwritingClient is the client API for Handwriting service.
+//
+// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
+type HandwritingClient interface {
+	Recognize(ctx context.Context, in *Matrix, opts ...grpc.CallOption) (*Result, error)
+	SetNeuralNetworkData(ctx context.Context, in *NeuralNetworkRaw, opts ...grpc.CallOption) (*None, error)
+}
+
+type handwritingClient struct {
+	cc *grpc.ClientConn
+}
+
+func NewHandwritingClient(cc *grpc.ClientConn) HandwritingClient {
+	return &handwritingClient{cc}
+}
+
+func (c *handwritingClient) Recognize(ctx context.Context, in *Matrix, opts ...grpc.CallOption) (*Result, error) {
+	out := new(Result)
+	err := c.cc.Invoke(ctx, "/handwriting.Handwriting/recognize", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+func (c *handwritingClient) SetNeuralNetworkData(ctx context.Context, in *NeuralNetworkRaw, opts ...grpc.CallOption) (*None, error) {
+	out := new(None)
+	err := c.cc.Invoke(ctx, "/handwriting.Handwriting/setNeuralNetworkData", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+// HandwritingServer is the server API for Handwriting service.
+type HandwritingServer interface {
+	Recognize(context.Context, *Matrix) (*Result, error)
+	SetNeuralNetworkData(context.Context, *NeuralNetworkRaw) (*None, error)
+}
+
+// UnimplementedHandwritingServer can be embedded to have forward compatible implementations.
+type UnimplementedHandwritingServer struct {
+}
+
+func (*UnimplementedHandwritingServer) Recognize(ctx context.Context, req *Matrix) (*Result, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method Recognize not implemented")
+}
+func (*UnimplementedHandwritingServer) SetNeuralNetworkData(ctx context.Context, req *NeuralNetworkRaw) (*None, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method SetNeuralNetworkData not implemented")
+}
+
+func RegisterHandwritingServer(s *grpc.Server, srv HandwritingServer) {
+	s.RegisterService(&_Handwriting_serviceDesc, srv)
+}
+
+func _Handwriting_Recognize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(Matrix)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(HandwritingServer).Recognize(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/handwriting.Handwriting/Recognize",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(HandwritingServer).Recognize(ctx, req.(*Matrix))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+func _Handwriting_SetNeuralNetworkData_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(NeuralNetworkRaw)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(HandwritingServer).SetNeuralNetworkData(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/handwriting.Handwriting/SetNeuralNetworkData",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(HandwritingServer).SetNeuralNetworkData(ctx, req.(*NeuralNetworkRaw))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+var _Handwriting_serviceDesc = grpc.ServiceDesc{
+	ServiceName: "handwriting.Handwriting",
+	HandlerType: (*HandwritingServer)(nil),
+	Methods: []grpc.MethodDesc{
+		{
+			MethodName: "recognize",
+			Handler:    _Handwriting_Recognize_Handler,
+		},
+		{
+			MethodName: "setNeuralNetworkData",
+			Handler:    _Handwriting_SetNeuralNetworkData_Handler,
+		},
+	},
+	Streams:  []grpc.StreamDesc{},
+	Metadata: "handwriting.proto",
+}

+ 8 - 1
handwriting/handwritingui/handwritingengine.cpp

@@ -53,7 +53,7 @@ HandwritingEngine::HandwritingEngine(QObject *parent) : QObject(parent)
   , m_client(new handwriting::HandwritingClient)
 {
 
-    auto chan = std::shared_ptr<QtProtobuf::QGrpcHttp2Channel>(new QtProtobuf::QGrpcHttp2Channel(QUrl("http://localhost:65002"), QtProtobuf::InsecureCredentials()|NoneCredencials()));
+    auto chan = std::shared_ptr<QtProtobuf::QGrpcHttp2Channel>(new QtProtobuf::QGrpcHttp2Channel(QUrl("http://localhost:65001"), QtProtobuf::InsecureCredentials()|NoneCredencials()));
     m_client->attachChannel(chan);
 
     m_matrix.setData(emptyData);
@@ -69,6 +69,7 @@ void HandwritingEngine::recognize()
 
 void HandwritingEngine::setNeuralNetworkData(const QString &networkDataPath)
 {
+    qDebug() << "HandwritingEngine::setNeuralNetworkData";
     QFile dataFile(QUrl(networkDataPath).toLocalFile());
     if (!dataFile.open(QFile::ReadOnly)) {
         qCritical() << "Could not open" << QUrl(networkDataPath).toLocalFile();
@@ -77,3 +78,9 @@ void HandwritingEngine::setNeuralNetworkData(const QString &networkDataPath)
 
     m_client->setNeuralNetworkData({dataFile.readAll()});
 }
+
+void HandwritingEngine::updateValue(int index, double value)
+{
+    m_matrix.data()[index] = value;
+}
+

+ 1 - 0
handwriting/handwritingui/handwritingengine.h

@@ -24,6 +24,7 @@ public:
 
     Q_INVOKABLE void recognize();
     Q_INVOKABLE void setNeuralNetworkData(const QString &networkDataPath);
+    Q_INVOKABLE void updateValue(int index, double value);
 
     int result() const
     {

+ 10 - 4
handwriting/handwritingui/main.qml

@@ -53,8 +53,8 @@ ApplicationWindow {
                 width: tileSize
                 height: tileSize
                 color: resetColor
-                x: tileSize*Math.floor(model.index/tileCount)
-                y: tileSize*(model.index%tileCount)
+                x: tileSize*(model.index%tileCount)
+                y: tileSize*Math.floor(model.index/tileCount)
                 Connections {
                     target: drawingArea
                     onPositionChanged: {
@@ -72,7 +72,7 @@ ApplicationWindow {
                             }
 
                             tile.color = Qt.rgba(newColor, newColor, newColor, 1.0)
-                            hwengine.matrix[model.index] = newColor
+                            hwengine.updateValue(model.index, newValue)
                         }
                     }
                 }
@@ -81,7 +81,7 @@ ApplicationWindow {
                     target: clearButton
                     onClicked: {
                         tile.color = resetColor
-                        hwengine.matrix[model.index] = 0.0
+                        hwengine.updateValue(model.index, 0.0)
                     }
                 }
             }
@@ -115,6 +115,9 @@ ApplicationWindow {
             Button {
                 id: uploadButton
                 text: "Upload"
+                onClicked: {
+                    hwengine.setNeuralNetworkData(pathToNeuralNetwork.text)
+                }
             }
         }
         Button {
@@ -124,6 +127,9 @@ ApplicationWindow {
         Button {
             id: recognizeButton
             text: "Recognize"
+            onClicked: {
+                hwengine.recognize()
+            }
         }
     }
 

+ 10 - 1
handwriting/main.go

@@ -1 +1,10 @@
- 
+package main
+
+import (
+	handwriting "./handwriting"
+)
+
+func main() {
+	hws := handwriting.NewHandwritingService()
+	hws.Run()
+}