Alexey Edelev 5 роки тому
батько
коміт
937dd91e1d

+ 1 - 1
gui/CMakeLists.txt

@@ -19,5 +19,5 @@ generate_qtprotobuf(TARGET NeuralNetworkUi PROTO_FILES ${PROTO_FILES})
 set(CMAKE_AUTOMOC ON)
 set(CMAKE_AUTORCC ON)
 
-add_executable(NeuralNetworkUi main.cpp qml.qrc valueindicator.cpp visualizermodel.cpp dense.cpp)
+add_executable(NeuralNetworkUi main.cpp qml.qrc valueindicator.cpp visualizermodel.cpp dense.cpp layertrigger.cpp)
 target_link_libraries(NeuralNetworkUi Qt5::Core Qt5::Gui Qt5::Qml Qt5::Quick QtProtobufProject::QtProtobuf QtProtobufProject::QtGrpc ${QtProtobuf_GENERATED})

+ 6 - 0
gui/dense.cpp

@@ -45,6 +45,12 @@ Dense::Dense(const QByteArray &data) : AbstractDense(*(int64_t *)(data.data() +
 {
 }
 
+double Dense::rawValue(int i) const
+{
+    return *(double *)(m_data.data() + 40 + i * sizeof(double));
+}
+
+
 template<>
 template<>
 double AbstractDense<QByteArray>::value<double>(int row, int column) const

+ 2 - 0
gui/dense.h

@@ -32,6 +32,8 @@ class Dense : public AbstractDense<QByteArray>
 {
 public:
     Dense(const QByteArray &data);
+
+    double rawValue(int i) const;
 };
 
 template<>

+ 6 - 0
gui/layertrigger.cpp

@@ -0,0 +1,6 @@
+#include "layertrigger.h"
+
+LayerTrigger::LayerTrigger(QObject *parent) : QObject(parent)
+{
+
+}

+ 16 - 0
gui/layertrigger.h

@@ -0,0 +1,16 @@
+#ifndef LAYERTRIGGER_H
+#define LAYERTRIGGER_H
+
+#include <QObject>
+
+class LayerTrigger : public QObject
+{
+    Q_OBJECT
+public:
+    explicit LayerTrigger(QObject *parent = nullptr);
+
+signals:
+    void updateLayer();
+};
+
+#endif // LAYERTRIGGER_H

+ 1 - 0
gui/main.cpp

@@ -46,6 +46,7 @@ int main(int argc, char *argv[])
     QGuiApplication app(argc, argv);
 
     qmlRegisterUncreatableType<ValueIndicator>("NeuralNetworkUi", 0, 1, "ValueIndicator", "");
+    qmlRegisterUncreatableType<LayerTrigger>("NeuralNetworkUi", 0, 1, "LayerTrigger", "");
     std::shared_ptr<remotecontrol::RemoteControlClient> client(new remotecontrol::RemoteControlClient);
     auto chan = std::shared_ptr<QtProtobuf::QGrpcHttp2Channel>(new QtProtobuf::QGrpcHttp2Channel(QUrl("http://localhost:65001"), QtProtobuf::InsecureCredentials()|NoneCredencials()));
     client->attachChannel(chan);

+ 15 - 9
gui/main.qml

@@ -74,9 +74,13 @@ ApplicationWindow {
 
                                 anchors.fill: parent
                                 radius: 15
-                                color: {
-                                    var alpha = activation.value
-                                    Qt.rgba(0, 1, 0, alpha)
+                                color: "transparent"
+                                function updateColor() {
+                                    var alpha = activation.value;
+                                    neuron.color = Qt.rgba(0, 1, 0, Math.max(0.08, alpha))
+                                }
+                                Component.onCompleted: {
+                                    visualizerModel.activationTrigger(layerIndex).updateLayer.connect(neuron.updateColor);
                                 }
                             }
                         }
@@ -101,13 +105,14 @@ ApplicationWindow {
                         var coordPrev = layerRepeater.itemAt(i - 1).mapToItem(root.contentItem, neuronPrev.x + neuronPrev.width/2, neuronPrev.y + neuronPrev.height/2)
                         var angle =  Math.atan2(coordPrev.y - coord.y, coordPrev.x - coord.x) * 180 / Math.PI
                         var length = Math.sqrt(Math.pow(coordPrev.x - coord.x, 2) + Math.pow(coordPrev.y - coord.y, 2))
-                        connection.createObject(bottomLayer, {
+                        var obj = connectionComponent.createObject(bottomLayer, {
                                                     x: coord.x,
                                                     y: coord.y,
                                                     width: length,
                                                     angle: angle,
                                                     weight: visualizerModel.weight(i, j, k),
                                                 })
+                        visualizerModel.weightTrigger(i).updateLayer.connect(obj.updateColor);
                     }
                 }
             }
@@ -115,8 +120,9 @@ ApplicationWindow {
     }
 
     Component {
-        id: connection
+        id: connectionComponent
         Rectangle {
+            id: connection
             property alias angle: trans.angle
             property ValueIndicator weight: null
 
@@ -124,11 +130,11 @@ ApplicationWindow {
             transform: Rotation {
                 id: trans
             }
-
+            color: "transparent"
             height: 1
-            color: {
-                var color = weight.value
-                Qt.rgba(color, 0, 1.0/color, color > 0 ? 0.5 : 0.0)
+            function updateColor() {
+                var newColor = weight.value;
+                connection.color = Qt.rgba(newColor, 0, 1.0 - newColor, newColor > 0 ? 0.5 : 0.0)
             }
         }
     }

+ 12 - 17
gui/valueindicator.cpp

@@ -30,15 +30,16 @@
 
 #include <QDebug>
 
-ValueIndicator::ValueIndicator() : QObject()
+ValueIndicator::ValueIndicator(ValueIndicatorDense *dense) : QObject()
   , m_value(0)
+  , m_dense(dense)
 {
 
 }
 
 qreal ValueIndicator::value() const
 {
-    return m_value;
+    return (m_value - m_dense->min())/(m_dense->max() - m_dense->min());
 }
 
 
@@ -53,25 +54,19 @@ void ValueIndicatorDense::updateValues(const Dense& dense)
 {
     m_max = std::numeric_limits<double>::min();
     m_min = std::numeric_limits<double>::max();
-    for(int i = 0; i < dense.rows(); i++) {
-        for(int j = 0; j < dense.columns(); j++) {
-            double val = dense.value<double>(i, j);
-            if (val > m_max) {
-                m_max = val;
-            }
 
-            if (val < m_min) {
-                m_min = val;
-            }
+    int i = 0;
+    for (auto value : m_data) {
+        double val = dense.rawValue(i);
+        if (val > m_max) {
+            m_max = val;
         }
-    }
-
-    for(int i = 0; i < dense.rows(); i++) {
-        for(int j = 0; j < dense.columns(); j++) {
-            double val = dense.value<double>(i, j);
 
-            value<ValueIndicator*>(i,j)->setValue((val - m_min)/(m_max - m_min));
+        if (val < m_min) {
+            m_min = val;
         }
+        value->setValue(val);
+        i++;
     }
 }
 

+ 3 - 8
gui/valueindicator.h

@@ -33,24 +33,19 @@ class ValueIndicatorDense;
 class ValueIndicator : public QObject
 {
     Q_OBJECT
-    Q_PROPERTY(qreal value READ value WRITE setValue NOTIFY valueChanged)
+    Q_PROPERTY(qreal value READ value CONSTANT)
 public:
-    ValueIndicator();
+    ValueIndicator(ValueIndicatorDense *dense);
     qreal value() const;
 
 public slots:
     void setValue(qreal value)
     {
-        if (qFuzzyCompare(m_value, value))
-            return;
-
         m_value = value;
-        emit valueChanged(m_value);
     }
-signals:
-    void valueChanged(qreal value);
 private:
     qreal m_value;
+    ValueIndicatorDense *m_dense;
 };
 
 class ValueIndicatorDense : public AbstractDense<QList<ValueIndicator*>> {

+ 33 - 20
gui/visualizermodel.cpp

@@ -42,21 +42,26 @@ VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, Q
         m_networkConfig = reply->read<Configuration>();
         for(int i = 0; i < m_networkConfig.sizes().size(); i++) {
             m_layers.append(new NetworkLayerState);
-            m_layers.last()->m_activations.setDimentions(m_networkConfig.sizes()[i], 1);
+            auto layerState = m_layers.last();
+            auto currenSize = m_networkConfig.sizes()[i];
+            auto &activations = layerState->m_activations;
+            auto &weights = layerState->m_weights;
+            activations.setDimentions(currenSize, 1);
             QList<ValueIndicator*> data;
-            for (int k = 0; k < m_networkConfig.sizes()[i]; k++) {
-                data.append(new ValueIndicator);
+            for (int k = 0; k < currenSize; k++) {
+                data.append(new ValueIndicator(&activations));
             }
-            m_layers.last()->m_activations.setData(data);
+            activations.setData(data);
 
             if (i != 0) {
-                int tolalItems = m_networkConfig.sizes()[i]*m_networkConfig.sizes()[i - 1];
-                m_layers.last()->m_weights.setDimentions(m_networkConfig.sizes()[i], m_networkConfig.sizes()[i - 1]);
+                auto previousSize = m_networkConfig.sizes()[i - 1];
+                int tolalItems = currenSize * previousSize;
+                weights.setDimentions(currenSize,previousSize);
                 data.clear();
                 for (int k = 0; k < tolalItems; k++) {
-                    data.append(new ValueIndicator);
+                    data.append(new ValueIndicator(&weights));
                 }
-                m_layers.last()->m_weights.setData(data);
+                weights.setData(data);
             }
         }
         sizesChanged();
@@ -66,25 +71,19 @@ VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, Q
         if (m_layers.isEmpty()) {
             return;
         }
-        Dense dense(activations.matrix().matrix());
-        m_layers[activations.layer()]->m_activations.updateValues(Dense(activations.matrix().matrix()));
-//        qDebug() << "ActivationsUpdated:" << dense.rows() << dense.columns() << activations.layer();
-    });
-    QObject::connect(client.get(), &remotecontrol::RemoteControlClient::BiasesUpdated, [this](const remotecontrol::LayerMatrix &biases) {
-        if (m_layers.isEmpty()) {
-            return;
-        }
-        Dense dense(biases.matrix().matrix());
-//        qDebug() << "BiasesUpdated:" << dense.rows() << dense.columns();
+        auto layer = m_layers[activations.layer()];
+        layer->m_activations.updateValues(Dense(activations.matrix().matrix()));
+        layer->m_actiovationTrigger.updateLayer();
     });
     QObject::connect(client.get(), &remotecontrol::RemoteControlClient::WeightsUpdated, [this](const remotecontrol::LayerMatrix &weights) {
         if (m_layers.isEmpty()) {
             return;
         }
-         m_layers[weights.layer()]->m_weights.updateValues(Dense(weights.matrix().matrix()));
+        auto layer = m_layers[weights.layer()];
+        layer->m_weights.updateValues(Dense(weights.matrix().matrix()));
+        layer->m_weightTrigger.updateLayer();
     });
     client->subscribeActivationsUpdates({});
-    client->subscribeBiasesUpdates({});
     client->subscribeWeightsUpdates({});
 }
 
@@ -101,3 +100,17 @@ ValueIndicator *VisualizerModel::weight(int layer, int row, int column)
     QQmlEngine::setObjectOwnership(indicator, QQmlEngine::CppOwnership);
     return indicator;
 }
+
+LayerTrigger *VisualizerModel::activationTrigger(int layer)
+{
+    LayerTrigger* trigger = &m_layers[layer]->m_actiovationTrigger;
+    QQmlEngine::setObjectOwnership(trigger, QQmlEngine::CppOwnership);
+    return trigger;
+}
+
+LayerTrigger *VisualizerModel::weightTrigger(int layer)
+{
+    LayerTrigger* trigger = &m_layers[layer]->m_weightTrigger;
+    QQmlEngine::setObjectOwnership(trigger, QQmlEngine::CppOwnership);
+    return trigger;
+}

+ 6 - 0
gui/visualizermodel.h

@@ -31,13 +31,17 @@
 #include "remotecontrolclient.h"
 #include "valueindicator.h"
 #include "abstractdense.h"
+#include "layertrigger.h"
 
 class ValueIndicator;
+class LayerTrigger;
 
 struct NetworkLayerState {
     ValueIndicatorDense m_activations;
     ValueIndicatorDense m_biases;
     ValueIndicatorDense m_weights;
+    LayerTrigger m_actiovationTrigger;
+    LayerTrigger m_weightTrigger;
 };
 
 class VisualizerModel : public QObject
@@ -53,6 +57,8 @@ public:
 
     Q_INVOKABLE ValueIndicator *activation(int layer, int row);
     Q_INVOKABLE ValueIndicator *weight(int layer, int row, int column);
+    Q_INVOKABLE LayerTrigger *activationTrigger(int layer);
+    Q_INVOKABLE LayerTrigger *weightTrigger(int layer);
 
 signals:
     void sizesChanged();

+ 4 - 3
neuralnetwork/main.go

@@ -12,7 +12,7 @@ import (
 )
 
 func main() {
-	sizes := []int{13, 12, 8, 12, 3}
+	sizes := []int{13, 8, 12, 3}
 	nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
 		NuPlus:   1.2,
 		NuMinus:  0.5,
@@ -43,7 +43,7 @@ func main() {
 	go func() {
 		// teacher := teach.NewMNISTReader("./minst.data", "./mnist.labels")
 		teacher := teach.NewTextDataReader("wine.data", 5)
-		nn.Teach(teacher, 1500)
+		nn.Teach(teacher, 500)
 
 		// for i := 0; i < nn.Count; i++ {
 		// 	if i > 0 {
@@ -67,6 +67,8 @@ func main() {
 		teacher.Reset()
 		for true {
 			if !teacher.NextValidator() {
+				fmt.Printf("Fail count: %v\n\n", failCount)
+				failCount = 0
 				teacher.Reset()
 			}
 			dataSet, expect := teacher.GetValidator()
@@ -78,7 +80,6 @@ func main() {
 				// fmt.Printf("Fail: %v, %v\n\n", teacher.ValidationIndex(), expect.At(index, 0))
 			}
 		}
-		fmt.Printf("Fail count: %v\n\n", failCount)
 	}()
 
 	// nn = &neuralnetwork.NeuralNetwork{}