12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- /*
- * MIT License
- *
- * Copyright (c) 2020 Alexey Edelev <semlanik@gmail.com>
- *
- * This file is part of NeuralNetwork project https://git.semlanik.org/semlanik/NeuralNetwork
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy of this
- * software and associated documentation files (the "Software"), to deal in the Software
- * without restriction, including without limitation the rights to use, copy, modify,
- * merge, publish, distribute, sublicense, and/or sell copies of the Software, and
- * to permit persons to whom the Software is furnished to do so, subject to the following
- * conditions:
- *
- * The above copyright notice and this permission notice shall be included in all copies
- * or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
- * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
- * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
- * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
- * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
- * DEALINGS IN THE SOFTWARE.
- */
- package neuralnetwork
- import (
- "runtime"
- training "git.semlanik.org/semlanik/NeuralNetwork/training"
- mat "gonum.org/v1/gonum/mat"
- )
- type localBatchWorkerFactory struct {
- network *NeuralNetwork
- }
- type localBatchWorker struct {
- network *NeuralNetwork
- BGradient []BatchGradientDescent
- WGradient []BatchGradientDescent
- batchSize int
- }
- func NewLocalBatchWorkerFactory(network *NeuralNetwork) BatchWorkerFactory {
- factory := &localBatchWorkerFactory{
- network: network,
- }
- return factory
- }
- func newLocalBatchWorker(nn *NeuralNetwork) (bw *localBatchWorker) {
- bw = &localBatchWorker{
- network: nn,
- BGradient: make([]BatchGradientDescent, nn.LayerCount),
- WGradient: make([]BatchGradientDescent, nn.LayerCount),
- }
- for l := 1; l < nn.LayerCount; l++ {
- bw.BGradient[l] = nn.gradientDescentInitializer(nn, l, BiasGradient).(BatchGradientDescent)
- bw.WGradient[l] = nn.gradientDescentInitializer(nn, l, WeightGradient).(BatchGradientDescent)
- }
- return
- }
- func (bw *localBatchWorker) Run(trainer training.Trainer, startIndex, endIndex int) {
- for i := startIndex; i < endIndex; i++ {
- dB, dW := bw.network.backward(trainer.GetData(i))
- for l := 1; l < bw.network.LayerCount; l++ {
- bw.BGradient[l].AccumGradients(dB[l])
- bw.WGradient[l].AccumGradients(dW[l])
- }
- }
- }
- func (bw *localBatchWorker) Result(layer int) (dB, dW *mat.Dense) {
- return bw.BGradient[layer].Gradients(), bw.WGradient[layer].Gradients()
- }
- func (lbwf localBatchWorkerFactory) GetBatchWorker() BatchWorker {
- return newLocalBatchWorker(lbwf.network)
- }
- func (lbwf localBatchWorkerFactory) GetAvailableThreads() int {
- return runtime.NumCPU()
- }
|