genetic.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package genetic
  2. import (
  3. "log"
  4. "sort"
  5. neuralnetwork "../neuralnetwork"
  6. )
  7. type Population struct {
  8. populationSize int
  9. Networks []*neuralnetwork.NeuralNetwork
  10. verifier PopulationVerifier
  11. mutagen Mutagen
  12. }
  13. func NewPopulation(verifier PopulationVerifier, mutagen Mutagen, populationSize int, sizes []int) (p *Population) {
  14. if populationSize%2 != 0 {
  15. return nil
  16. }
  17. p = &Population{
  18. populationSize: populationSize,
  19. Networks: make([]*neuralnetwork.NeuralNetwork, populationSize),
  20. verifier: verifier,
  21. mutagen: mutagen,
  22. }
  23. for i := 0; i < populationSize; i++ {
  24. var err error
  25. p.Networks[i], err = neuralnetwork.NewNeuralNetwork(sizes, nil)
  26. if err != nil {
  27. log.Fatal("Could not initialize NeuralNetwork")
  28. }
  29. }
  30. return
  31. }
  32. func (p *Population) NaturalSelection(generationCount int) {
  33. for g := 0; g < generationCount; g++ {
  34. p.crossbreedPopulation(p.verifier.Verify(p))
  35. }
  36. }
  37. func (p *Population) crossbreedPopulation(results []*IndividalResult) {
  38. sort.Slice(results, func(i, j int) bool {
  39. return results[i].result < results[j].result
  40. })
  41. for i := 1; i < p.populationSize; i += 2 {
  42. firstParent := results[i].index
  43. secondParent := results[i-1].index
  44. crossbreed(p.Networks[firstParent], p.Networks[secondParent])
  45. p.mutagen.Mutate(p.Networks[firstParent])
  46. p.mutagen.Mutate(p.Networks[secondParent])
  47. }
  48. }
  49. func crossbreed(firstParent, secondParent *neuralnetwork.NeuralNetwork) {
  50. for l := 1; l < firstParent.LayerCount; l++ {
  51. firstParentWeights := firstParent.Weights[l]
  52. secondParentWeights := secondParent.Weights[l]
  53. firstParentBiases := firstParent.Biases[l]
  54. secondParentBiases := secondParent.Biases[l]
  55. r, c := firstParentWeights.Dims()
  56. for i := 0; i < r/2; i++ {
  57. for j := 0; j < c; j++ {
  58. // Swap first half of weights
  59. w := firstParentWeights.At(i, j)
  60. firstParentWeights.Set(i, j, secondParentWeights.At(i, j))
  61. secondParentWeights.Set(i, j, w)
  62. }
  63. // Swap first half of biases
  64. b := firstParentBiases.At(i, 0)
  65. firstParentBiases.Set(i, 0, secondParentBiases.At(i, 0))
  66. secondParentBiases.Set(i, 0, b)
  67. }
  68. }
  69. }