Alexey Edelev пре 5 година
родитељ
комит
777e091384

+ 15 - 3
gui/main.qml

@@ -70,8 +70,9 @@ ApplicationWindow {
                             color: "#00ffffff"
                             Rectangle {
                                 id: neuron
-                                anchors.fill: parent
                                 property ValueIndicator activation: visualizerModel.activation(layerIndex, model.index)
+
+                                anchors.fill: parent
                                 radius: 15
                                 color: {
                                     var alpha = activation.value
@@ -100,7 +101,13 @@ ApplicationWindow {
                         var coordPrev = layerRepeater.itemAt(i - 1).mapToItem(root.contentItem, neuronPrev.x + neuronPrev.width/2, neuronPrev.y + neuronPrev.height/2)
                         var angle =  Math.atan2(coordPrev.y - coord.y, coordPrev.x - coord.x) * 180 / Math.PI
                         var length = Math.sqrt(Math.pow(coordPrev.x - coord.x, 2) + Math.pow(coordPrev.y - coord.y, 2))
-                        connection.createObject(bottomLayer, {x: coord.x, y: coord.y, width: length, angle: angle })
+                        connection.createObject(bottomLayer, {
+                                                    x: coord.x,
+                                                    y: coord.y,
+                                                    width: length,
+                                                    angle: angle,
+                                                    weight: visualizerModel.weight(i, j, k),
+                                                })
                     }
                 }
             }
@@ -111,13 +118,18 @@ ApplicationWindow {
         id: connection
         Rectangle {
             property alias angle: trans.angle
+            property ValueIndicator weight: null
+
             transformOrigin: Item.Left
             transform: Rotation {
                 id: trans
             }
 
             height: 1
-            color: "#55ffffff"
+            color: {
+                var color = weight.value
+                Qt.rgba(color, 0, 1.0/color, color > 0 ? 0.5 : 0.0)
+            }
         }
     }
 }

+ 30 - 1
gui/valueindicator.cpp

@@ -25,23 +25,52 @@
 
 #include "valueindicator.h"
 
+#include <limits>
+#include <cmath>
+
+#include <QDebug>
+
 ValueIndicator::ValueIndicator() : QObject()
   , m_value(0)
 {
 
 }
 
+qreal ValueIndicator::value() const
+{
+    return m_value;
+}
+
 
 ValueIndicatorDense::ValueIndicatorDense(int rows, int columns, const QList<ValueIndicator*>& data) : AbstractDense(rows, columns, data)
+  , m_max(std::numeric_limits<double>::min())
+  , m_min(std::numeric_limits<double>::max())
 {
 
 }
 
 void ValueIndicatorDense::updateValues(const Dense& dense)
 {
+    m_max = std::numeric_limits<double>::min();
+    m_min = std::numeric_limits<double>::max();
+    for(int i = 0; i < dense.rows(); i++) {
+        for(int j = 0; j < dense.columns(); j++) {
+            double val = dense.value<double>(i, j);
+            if (val > m_max) {
+                m_max = val;
+            }
+
+            if (val < m_min) {
+                m_min = val;
+            }
+        }
+    }
+
     for(int i = 0; i < dense.rows(); i++) {
         for(int j = 0; j < dense.columns(); j++) {
-            value<ValueIndicator*>(i,j)->setValue(dense.value<double>(i, j));
+            double val = dense.value<double>(i, j);
+
+            value<ValueIndicator*>(i,j)->setValue((val - m_min)/(m_max - m_min));
         }
     }
 }

+ 15 - 5
gui/valueindicator.h

@@ -29,20 +29,18 @@
 #include "abstractdense.h"
 #include "dense.h"
 
+class ValueIndicatorDense;
 class ValueIndicator : public QObject
 {
     Q_OBJECT
     Q_PROPERTY(qreal value READ value WRITE setValue NOTIFY valueChanged)
 public:
     ValueIndicator();
-    qreal value() const
-    {
-        return m_value;
-    }
+    qreal value() const;
+
 public slots:
     void setValue(qreal value)
     {
-//        qWarning("Floating point comparison needs context sanity check");
         if (qFuzzyCompare(m_value, value))
             return;
 
@@ -61,4 +59,16 @@ public:
     virtual ~ValueIndicatorDense();
     ValueIndicatorDense(int rows, int columns, const QList<ValueIndicator*>& data);
     void updateValues(const Dense& dense);
+
+    double min() const {
+        return m_min;
+    }
+
+    double max() const {
+        return m_max;
+    }
+
+private:
+    double m_max;
+    double m_min;
 };

+ 11 - 2
gui/visualizermodel.cpp

@@ -48,6 +48,16 @@ VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, Q
                 data.append(new ValueIndicator);
             }
             m_layers.last()->m_activations.setData(data);
+
+            if (i != 0) {
+                int tolalItems = m_networkConfig.sizes()[i]*m_networkConfig.sizes()[i - 1];
+                m_layers.last()->m_weights.setDimentions(m_networkConfig.sizes()[i], m_networkConfig.sizes()[i - 1]);
+                data.clear();
+                for (int k = 0; k < tolalItems; k++) {
+                    data.append(new ValueIndicator);
+                }
+                m_layers.last()->m_weights.setData(data);
+            }
         }
         sizesChanged();
     });
@@ -71,8 +81,7 @@ VisualizerModel::VisualizerModel(std::shared_ptr<RemoteControlClient> &client, Q
         if (m_layers.isEmpty()) {
             return;
         }
-        Dense dense(weights.matrix().matrix());
-//        qDebug() << "WeightsUpdated:" << dense.rows() << dense.columns();
+         m_layers[weights.layer()]->m_weights.updateValues(Dense(weights.matrix().matrix()));
     });
     client->subscribeActivationsUpdates({});
     client->subscribeBiasesUpdates({});

+ 6 - 2
neuralnetwork/main.go

@@ -12,7 +12,7 @@ import (
 )
 
 func main() {
-	sizes := []int{13, 12, 12, 3}
+	sizes := []int{13, 12, 8, 12, 3}
 	nn, _ := neuralnetwork.NewNeuralNetwork(sizes, neuralnetwork.NewRPropInitializer(neuralnetwork.RPropConfig{
 		NuPlus:   1.2,
 		NuMinus:  0.5,
@@ -65,9 +65,13 @@ func main() {
 		time.Sleep(5 * time.Second)
 		failCount := 0
 		teacher.Reset()
-		for teacher.NextValidator() {
+		for true {
+			if !teacher.NextValidator() {
+				teacher.Reset()
+			}
 			dataSet, expect := teacher.GetValidator()
 			index, _ := nn.Predict(dataSet)
+			//TODO: remove this is not used for visualization
 			time.Sleep(400 * time.Millisecond)
 			if expect.At(index, 0) != 1.0 {
 				failCount++

+ 3 - 0
neuralnetwork/neuralnetworkbase/neuralnetwork.go

@@ -32,6 +32,7 @@ import (
 	"io"
 	"runtime"
 	"sync"
+	"time"
 
 	teach "../teach"
 	mat "gonum.org/v1/gonum/mat"
@@ -227,6 +228,8 @@ func (nn *NeuralNetwork) TeachBatch(teacher teach.Teacher, epocs int) {
 				nn.watcher.UpdateWeights(l, nn.Weights[l])
 			}
 		}
+		//TODO: remove this is not used for visualization
+		time.Sleep(100 * time.Millisecond)
 	}
 }