Ver código fonte

Make simple activation indication

Alexey Edelev 5 anos atrás
pai
commit
319a065e18
10 arquivos alterados com 179 adições e 36 exclusões
  1. 63 0
      gui/abstractdense.h
  2. 5 9
      gui/dense.cpp
  3. 6 14
      gui/dense.h
  4. 2 0
      gui/main.cpp
  5. 12 2
      gui/main.qml
  6. 22 1
      gui/valueindicator.cpp
  7. 13 2
      gui/valueindicator.h
  8. 40 6
      gui/visualizermodel.cpp
  9. 14 2
      gui/visualizermodel.h
  10. 2 0
      neuralnetwork/main.go

+ 63 - 0
gui/abstractdense.h

@@ -0,0 +1,63 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2019 Alexey Edelev <semlanik@gmail.com>
+ *
+ * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+ * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+#pragma once
+
+template <typename T>
+class AbstractDense {
+public:
+    AbstractDense() = default;
+    AbstractDense(int rows, int columns, const T &data) :
+        m_rows(rows)
+      , m_columns(columns)
+      , m_data(data) {}
+
+    void setDimentions(int rows, int columns) {
+        m_rows = rows;
+        m_columns = columns;
+    }
+
+    void setData(const T &data) {
+        m_data = data;
+    }
+
+    int rows() const {
+        return m_rows;
+    }
+
+    int columns() const {
+        return m_columns;
+    }
+
+    template<typename R>
+    R value(int row, int column) const {
+        return m_data[(m_columns - 1) * row + column + row];
+    }
+
+protected:
+    int m_rows;
+    int m_columns;
+    T m_data;
+};

+ 5 - 9
gui/dense.cpp

@@ -41,18 +41,14 @@
 //        ...
 //        [nrows-1,0] ... [nrows-1,ncols-1]
 
-Dense::Dense(const QByteArray &data)
+Dense::Dense(const QByteArray &data) : AbstractDense(*(int64_t *)(data.data() + 8), *(int64_t *)(data.data() + 16), data)
 {
-    m_data = data;
-    m_rows = *(int64_t *)(data.data() + 8);
-    m_columns = *(int64_t *)(data.data() + 16);
-
-    Q_ASSERT(m_rows * m_columns * sizeof(double) + 40 == data.size());
 }
 
-double Dense::value(int row, int column)
+template<>
+template<>
+double AbstractDense<QByteArray>::value<double>(int row, int column) const
 {
-    char *dataPtr = m_data.data() + 40 + (m_rows * row + m_columns * column) * sizeof(double);
+    const char *dataPtr = m_data.data() + 40 + ((m_columns - 1) * row + column + row) * sizeof(double);
     return *(double *)dataPtr;
 }
-

+ 6 - 14
gui/dense.h

@@ -26,22 +26,14 @@
 #pragma once
 
 #include <QByteArray>
+#include "abstractdense.h"
 
-class Dense
+class Dense : public AbstractDense<QByteArray>
 {
 public:
     Dense(const QByteArray &data);
-    double value(int row, int column);
-
-    int rows() {
-        return m_rows;
-    }
-
-    int columns() {
-        return m_columns;
-    }
-private:
-    int m_rows;
-    int m_columns;
-    QByteArray m_data;
 };
+
+template<>
+template<>
+double AbstractDense<QByteArray>::value<double>(int row, int column) const;

+ 2 - 0
gui/main.cpp

@@ -32,6 +32,7 @@
 #include "qgrpchttp2channel.h"
 #include "insecurecredentials.h"
 #include "visualizermodel.h"
+#include "valueindicator.h"
 
 class NoneCredencials : public QtProtobuf::CallCredentials
 {
@@ -44,6 +45,7 @@ int main(int argc, char *argv[])
 {
     QGuiApplication app(argc, argv);
 
+    qmlRegisterUncreatableType<ValueIndicator>("NeuralNetworkUi", 0, 1, "ValueIndicator", "");
     std::shared_ptr<remotecontrol::RemoteControlClient> client(new remotecontrol::RemoteControlClient);
     auto chan = std::shared_ptr<QtProtobuf::QGrpcHttp2Channel>(new QtProtobuf::QGrpcHttp2Channel(QUrl("http://localhost:65001"), QtProtobuf::InsecureCredentials()|NoneCredencials()));
     client->attachChannel(chan);

+ 12 - 2
gui/main.qml

@@ -26,6 +26,7 @@
 import QtQuick 2.11
 import QtQuick.Window 2.11
 import QtQuick.Controls 1.4
+import NeuralNetworkUi 0.1
 
 ApplicationWindow {
     id: root
@@ -64,9 +65,18 @@ ApplicationWindow {
                         model: layerSize
                         delegate: Rectangle {
                             id: neuron
+                            property ValueIndicator activation: visualizerModel.activation(layerIndex, model.index)
                             width: 30
                             height: 30
                             radius: 15
+                            color: {
+                                if (activation.value > 0) {
+                                    var alpha = activation.value
+                                    Qt.rgba(0, 1, 0, alpha)
+                                } else {
+                                    "#ffffffff"
+                                }
+                            }
                         }
                     }
                 }
@@ -89,7 +99,7 @@ ApplicationWindow {
                         var coordPrev = layerRepeater.itemAt(i - 1).mapToItem(root.contentItem, neuronPrev.x + neuronPrev.width/2, neuronPrev.y + neuronPrev.height/2)
                         var angle =  Math.atan2(coordPrev.y - coord.y, coordPrev.x - coord.x) * 180 / Math.PI
                         var length = Math.sqrt(Math.pow(coordPrev.x - coord.x, 2) + Math.pow(coordPrev.y - coord.y, 2))
-                        connection.createObject(bottomLayer, {x: coord.x, y: coord.y, width: length, angle: angle})
+                        connection.createObject(bottomLayer, {x: coord.x, y: coord.y, width: length, angle: angle })
                     }
                 }
             }
@@ -106,7 +116,7 @@ ApplicationWindow {
             }
 
             height: 1
-            color: "#5500ff00"
+            color: "#55ffffff"
         }
     }
 }

+ 22 - 1
gui/valueindicator.cpp

@@ -25,7 +25,28 @@
 
 #include "valueindicator.h"
 
-ValueIndicator::ValueIndicator()
+ValueIndicator::ValueIndicator() : QObject()
+  , m_value(0)
 {
 
 }
+
+
+ValueIndicatorDense::ValueIndicatorDense(int rows, int columns, const QList<ValueIndicator*>& data) : AbstractDense(rows, columns, data)
+{
+
+}
+
+void ValueIndicatorDense::updateValues(const Dense& dense)
+{
+    for(int i = 0; i < dense.rows(); i++) {
+        for(int j = 0; j < dense.columns(); j++) {
+            value<ValueIndicator*>(i,j)->setValue(dense.value<double>(i, j));
+        }
+    }
+}
+
+ValueIndicatorDense::~ValueIndicatorDense()
+{
+    qDeleteAll(m_data);
+}

+ 13 - 2
gui/valueindicator.h

@@ -26,9 +26,12 @@
 #pragma once
 
 #include <QObject>
+#include "abstractdense.h"
+#include "dense.h"
 
-class ValueIndicator
+class ValueIndicator : public QObject
 {
+    Q_OBJECT
     Q_PROPERTY(qreal value READ value WRITE setValue NOTIFY valueChanged)
 public:
     ValueIndicator();
@@ -39,7 +42,7 @@ public:
 public slots:
     void setValue(qreal value)
     {
-        qWarning("Floating point comparison needs context sanity check");
+//        qWarning("Floating point comparison needs context sanity check");
         if (qFuzzyCompare(m_value, value))
             return;
 
@@ -51,3 +54,11 @@ signals:
 private:
     qreal m_value;
 };
+
+class ValueIndicatorDense : public AbstractDense<QList<ValueIndicator*>> {
+public:
+    ValueIndicatorDense() = default;
+    virtual ~ValueIndicatorDense();
+    ValueIndicatorDense(int rows, int columns, const QList<ValueIndicator*>& data);
+    void updateValues(const Dense& dense);
+};

+ 40 - 6
gui/visualizermodel.cpp

@@ -38,23 +38,57 @@ VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, Q
   , m_client(client)
 {
     m_client->getConfiguration({}, this, [this](QGrpcAsyncReply *reply) {
+        qDeleteAll(m_layers);
         m_networkConfig = reply->read<Configuration>();
+        for(int i = 0; i < m_networkConfig.sizes().size(); i++) {
+            m_layers.append(new NetworkLayerState);
+            m_layers.last()->m_activations.setDimentions(m_networkConfig.sizes()[i], 1);
+            QList<ValueIndicator*> data;
+            for (int k = 0; k < m_networkConfig.sizes()[i]; k++) {
+                data.append(new ValueIndicator);
+            }
+            m_layers.last()->m_activations.setData(data);
+        }
         sizesChanged();
     });
 
-    QObject::connect(client.get(), &remotecontrol::RemoteControlClient::ActivationsUpdated, [](const remotecontrol::LayerMatrix &activations) {
+    QObject::connect(client.get(), &remotecontrol::RemoteControlClient::ActivationsUpdated, [this](const remotecontrol::LayerMatrix &activations) {
+        if (m_layers.isEmpty()) {
+            return;
+        }
         Dense dense(activations.matrix().matrix());
-        qDebug() << "ActivationsUpdated:" << dense.rows() << dense.columns();
+        m_layers[activations.layer()]->m_activations.updateValues(Dense(activations.matrix().matrix()));
+//        qDebug() << "ActivationsUpdated:" << dense.rows() << dense.columns() << activations.layer();
     });
-    QObject::connect(client.get(), &remotecontrol::RemoteControlClient::BiasesUpdated, [](const remotecontrol::LayerMatrix &biases) {
+    QObject::connect(client.get(), &remotecontrol::RemoteControlClient::BiasesUpdated, [this](const remotecontrol::LayerMatrix &biases) {
+        if (m_layers.isEmpty()) {
+            return;
+        }
         Dense dense(biases.matrix().matrix());
-        qDebug() << "BiasesUpdated:" << dense.rows() << dense.columns();
+//        qDebug() << "BiasesUpdated:" << dense.rows() << dense.columns();
     });
-    QObject::connect(client.get(), &remotecontrol::RemoteControlClient::WeightsUpdated, [](const remotecontrol::LayerMatrix &weights) {
+    QObject::connect(client.get(), &remotecontrol::RemoteControlClient::WeightsUpdated, [this](const remotecontrol::LayerMatrix &weights) {
+        if (m_layers.isEmpty()) {
+            return;
+        }
         Dense dense(weights.matrix().matrix());
-        qDebug() << "WeightsUpdated:" << dense.rows() << dense.columns();
+//        qDebug() << "WeightsUpdated:" << dense.rows() << dense.columns();
     });
     client->subscribeActivationsUpdates({});
     client->subscribeBiasesUpdates({});
     client->subscribeWeightsUpdates({});
 }
+
+ValueIndicator *VisualizerModel::activation(int layer, int row)
+{
+    ValueIndicator* indicator = m_layers[layer]->m_activations.value<ValueIndicator*>(row, 0);
+    QQmlEngine::setObjectOwnership(indicator, QQmlEngine::CppOwnership);
+    return indicator;
+}
+
+ValueIndicator *VisualizerModel::weight(int layer, int row, int column)
+{
+    ValueIndicator* indicator = m_layers[layer]->m_weights.value<ValueIndicator*>(row, column);
+    QQmlEngine::setObjectOwnership(indicator, QQmlEngine::CppOwnership);
+    return indicator;
+}

+ 14 - 2
gui/visualizermodel.h

@@ -27,7 +27,18 @@
 
 #include <QObject>
 #include <QGenericMatrix>
+
 #include "remotecontrolclient.h"
+#include "valueindicator.h"
+#include "abstractdense.h"
+
+class ValueIndicator;
+
+struct NetworkLayerState {
+    ValueIndicatorDense m_activations;
+    ValueIndicatorDense m_biases;
+    ValueIndicatorDense m_weights;
+};
 
 class VisualizerModel : public QObject
 {
@@ -40,8 +51,8 @@ public:
         return m_networkConfig.sizes();
     }
 
-//    ValueIndicator *activation(int layer, int row);
-//    ValueIndicator *weight(int layer, int row, int column);
+    Q_INVOKABLE ValueIndicator *activation(int layer, int row);
+    Q_INVOKABLE ValueIndicator *weight(int layer, int row, int column);
 
 signals:
     void sizesChanged();
@@ -49,4 +60,5 @@ signals:
 private:
     std::shared_ptr<remotecontrol::RemoteControlClient> &m_client;
     remotecontrol::Configuration m_networkConfig;
+    QList<NetworkLayerState*> m_layers;
 };

+ 2 - 0
neuralnetwork/main.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"log"
 	"os"
+	"time"
 
 	neuralnetwork "./neuralnetworkbase"
 	remotecontrol "./remotecontrol"
@@ -66,6 +67,7 @@ func main() {
 		for teacher.NextValidator() {
 			dataSet, expect := teacher.GetValidator()
 			index, _ := nn.Predict(dataSet)
+			time.Sleep(400 * time.Millisecond)
 			if expect.At(index, 0) != 1.0 {
 				failCount++
 				// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))