12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- package neuralnetworkbase
- import (
- teach "../teach"
- mat "gonum.org/v1/gonum/mat"
- )
// batchWorker accumulates per-layer bias and weight gradients over a
// contiguous range of training samples, for batch gradient descent.
type batchWorker struct {
	network *NeuralNetwork // network whose backward pass produces the gradients
	// Per-layer accumulators; index 0 is unused (layer indices start at 1,
	// matching the loops in newBatchWorker and Run).
	BGradient []BatchGradientDescent // bias-gradient accumulator per layer
	WGradient []BatchGradientDescent // weight-gradient accumulator per layer
	batchSize int                    // NOTE(review): not referenced in this chunk — presumably used elsewhere; confirm
}
- func newBatchWorker(nn *NeuralNetwork) (bw *batchWorker) {
- bw = &batchWorker{
- network: nn,
- BGradient: make([]BatchGradientDescent, nn.LayerCount),
- WGradient: make([]BatchGradientDescent, nn.LayerCount),
- }
- for l := 1; l < nn.LayerCount; l++ {
- bw.BGradient[l] = nn.gradientDescentInitializer(nn, l, BiasGradient).(BatchGradientDescent)
- bw.WGradient[l] = nn.gradientDescentInitializer(nn, l, WeightGradient).(BatchGradientDescent)
- }
- return
- }
- func (bw *batchWorker) Run(teacher teach.Teacher, startIndex, endIndex int) {
- for i := startIndex; i < endIndex; i++ {
- dB, dW := bw.network.backward(teacher.GetDataByIndex(i))
- for l := 1; l < bw.network.LayerCount; l++ {
- bw.BGradient[l].AccumGradients(dB[l])
- bw.WGradient[l].AccumGradients(dW[l])
- }
- }
- teacher.Reset()
- }
- func (bw *batchWorker) Result(layer int) (dB, dW *mat.Dense) {
- return bw.BGradient[layer].Gradients(), bw.WGradient[layer].Gradients()
- }
|