|
@@ -40,23 +40,41 @@ import (
|
|
|
)
|
|
|
|
|
|
type RemoteControl struct {
|
|
|
- nn *neuralnetworkbase.NeuralNetwork
|
|
|
+ nn *neuralnetworkbase.NeuralNetwork
|
|
|
+ activationsQueue chan *LayerMatrix
|
|
|
+ biasesQueue chan *LayerMatrix
|
|
|
+ weightsQueue chan *LayerMatrix
|
|
|
}
|
|
|
|
|
|
func (rw *RemoteControl) Init(nn *neuralnetworkbase.NeuralNetwork) {
|
|
|
rw.nn = nn
|
|
|
+ rw.activationsQueue = make(chan *LayerMatrix, 10)
|
|
|
+ rw.biasesQueue = make(chan *LayerMatrix, 10)
|
|
|
+ rw.weightsQueue = make(chan *LayerMatrix, 10)
|
|
|
}
|
|
|
|
|
|
func (rw *RemoteControl) UpdateActivations(l int, a *mat.Dense) {
|
|
|
- // matrix := NewLayerMatrix(l, a, LayerMatrix_Activations)
|
|
|
+ matrix := NewLayerMatrix(l, a, LayerMatrix_Activations)
|
|
|
+ select {
|
|
|
+ case rw.activationsQueue <- matrix:
|
|
|
+ default:
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func (rw *RemoteControl) UpdateBiases(l int, biases *mat.Dense) {
|
|
|
- // matrix := NewLayerMatrix(l, biases, LayerMatrix_Biases)
|
|
|
+ matrix := NewLayerMatrix(l, biases, LayerMatrix_Biases)
|
|
|
+ select {
|
|
|
+ case rw.biasesQueue <- matrix:
|
|
|
+ default:
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func (rw *RemoteControl) UpdateWeights(l int, weights *mat.Dense) {
|
|
|
- // matrix := NewLayerMatrix(l, weights, LayerMatrix_Weights)
|
|
|
+ matrix := NewLayerMatrix(l, weights, LayerMatrix_Weights)
|
|
|
+ select {
|
|
|
+ case rw.weightsQueue <- matrix:
|
|
|
+ default:
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func NewLayerMatrix(l int, dense *mat.Dense, contentType LayerMatrix_ContentType) (matrix *LayerMatrix) {
|
|
@@ -80,16 +98,43 @@ func (rw *RemoteControl) GetConfiguration(context.Context, *None) (*Configuratio
|
|
|
return nil, status.Error(codes.Unimplemented, "Not implemented")
|
|
|
}
|
|
|
|
|
|
-func (rw *RemoteControl) Activations(*None, RemoteControl_ActivationsServer) error {
|
|
|
- return status.Error(codes.Unimplemented, "Not implemented")
|
|
|
+func (rw *RemoteControl) Activations(_ *None, srv RemoteControl_ActivationsServer) error {
|
|
|
+ ctx := srv.Context()
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ msg := <-rw.activationsQueue
|
|
|
+ srv.Send(msg)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
-func (rw *RemoteControl) Biases(*None, RemoteControl_BiasesServer) error {
|
|
|
- return status.Error(codes.Unimplemented, "Not implemented")
|
|
|
+func (rw *RemoteControl) Biases(_ *None, srv RemoteControl_BiasesServer) error {
|
|
|
+ ctx := srv.Context()
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ msg := <-rw.biasesQueue
|
|
|
+ srv.Send(msg)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
-func (rw *RemoteControl) Weights(*None, RemoteControl_WeightsServer) error {
|
|
|
- return status.Error(codes.Unimplemented, "Not implemented")
|
|
|
+func (rw *RemoteControl) Weights(_ *None, srv RemoteControl_WeightsServer) error {
|
|
|
+ ctx := srv.Context()
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ msg := <-rw.weightsQueue
|
|
|
+ srv.Send(msg)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func (rw *RemoteControl) Predict(context.Context, *Matrix) (*Matrix, error) {
|