Table of Contents
Correct answer
Home Backend Development Golang My neural network is trained (from scratch) to get further away from the goal

My neural network is trained (from scratch) to get further away from the goal

Feb 06, 2024 am 10:36 AM

My neural network is trained (from scratch) to get further away from the goal

Question content

This is my first time creating a neural network and I decided to create it in golang which is usually not the language used for this purpose , but I want to get a good understanding of how they work from scratch with only the basic libraries.

The goal of this program is to train a neural network to be able to add two numbers (1-10). To do this, I created a neural network class called rawai (the best name I could think of) and gave it a 1 input layer (array of size 2), 1 hidden layer (array of size 2 array) and 1 output layer (array of size 1).

The weight has two 2d arrays, one is ih (hidden input) [2,2], and the other is ho, [2,1].

The following is the code to start ai, train and test ai. You will see several debugging statements that I have used and any other functions that are not golang or its packages will be shown in the following code of my rawai class. This is called by my main function:

func additionneuralnetworktest() {
    nn := newrawai(2, 2, 1, 1/math.pow(10, 15))
    fmt.printf("weights ih before: %v\n\nweights ho after: %v\n", nn.weightsih, nn.weightsho)
    //train neural network
    //
    for epoch := 0; epoch < 10000000; epoch++ {
        for i := 0; i <= 10; i++ {
            for j := 0; j <= 10; j++ {
                inputs := make([]float64, 2)
                targets := make([]float64, 1)
                inputs[0] = float64(i)
                inputs[1] = float64(j)
                targets[0] = float64(i) + float64(j)
                nn.train(inputs, targets)
                if epoch%20000 == 0 && i == 5 && j == 5 {
                    fmt.printf("[training] [epoch %d] %f + %f = %f targets[%f]\n", epoch, inputs[0], inputs[1], nn.outputlayer[0], targets[0])
                }

            }

        }
    }
    // test neural network
    a := rand.intn(10) + 1
    b := rand.intn(10) + 1
    inputs := make([]float64, 2)
    inputs[0] = float64(a)
    inputs[1] = float64(b)
    prediction := nn.feedforward(inputs)[0]
    fmt.printf("%d + %d = %f\n", a, b, prediction)
    fmt.printf("weights ih: %v\n\nweights ho: %v\n", nn.weightsih, nn.weightsho)

}
Copy after login

The following is all the code in the rawai file:

type RawAI struct {
    InputLayer   []float64   `json:"input_layer"`
    HiddenLayer  []float64   `json:"hidden_layer"`
    OutputLayer  []float64   `json:"output_layer"`
    WeightsIH    [][]float64 `json:"weights_ih"`
    WeightsHO    [][]float64 `json:"weights_ho"`
    LearningRate float64     `json:"learning_rate"`
}

func NewRawAI(inputSize, hiddenSize, outputSize int, learningRate float64) *RawAI {
    nn := RawAI{
        InputLayer:   make([]float64, inputSize),
        HiddenLayer:  make([]float64, hiddenSize),
        OutputLayer:  make([]float64, outputSize),
        WeightsIH:    randomMatrix(inputSize, hiddenSize),
        WeightsHO:    randomMatrix(hiddenSize, outputSize),
        LearningRate: learningRate,
    }
    return &nn
}
func (nn *RawAI) FeedForward(inputs []float64) []float64 {
    // Set input layer
    for i := 0; i < len(inputs); i++ {
        nn.InputLayer[i] = inputs[i]
    }

    // Compute hidden layer
    for i := 0; i < len(nn.HiddenLayer); i++ {
        sum := 0.0
        for j := 0; j < len(nn.InputLayer); j++ {
            sum += nn.InputLayer[j] * nn.WeightsIH[j][i]
        }
        nn.HiddenLayer[i] = sum
        if math.IsNaN(sum) {
            panic(fmt.Sprintf("Sum is NaN on Hidden Layer:\nInput Layer: %v\nHidden Layer: %v\nWeights IH: %v\n", nn.InputLayer, nn.HiddenLayer, nn.WeightsIH))
        }

    }

    // Compute output layer
    for k := 0; k < len(nn.OutputLayer); k++ {
        sum := 0.0
        for j := 0; j < len(nn.HiddenLayer); j++ {
            sum += nn.HiddenLayer[j] * nn.WeightsHO[j][k]
        }
        nn.OutputLayer[k] = sum
        if math.IsNaN(sum) {
            panic(fmt.Sprintf("Sum is NaN on Output Layer:\n Model: %v\n", nn))
        }

    }

    return nn.OutputLayer
}
func (nn *RawAI) Train(inputs []float64, targets []float64) {
    nn.FeedForward(inputs)

    // Compute output layer error
    outputErrors := make([]float64, len(targets))
    for k := 0; k < len(targets); k++ {
        outputErrors[k] = targets[k] - nn.OutputLayer[k]
    }

    // Compute hidden layer error
    hiddenErrors := make([]float64, len(nn.HiddenLayer))
    for j := 0; j < len(nn.HiddenLayer); j++ {
        errorSum := 0.0
        for k := 0; k < len(nn.OutputLayer); k++ {
            errorSum += outputErrors[k] * nn.WeightsHO[j][k]
        }
        hiddenErrors[j] = errorSum * sigmoidDerivative(nn.HiddenLayer[j])
        if math.IsInf(math.Abs(hiddenErrors[j]), 1) {
            //Find out why
            fmt.Printf("Hidden Error is Infinite:\nTargets:%v\nOutputLayer:%v\n\n", targets, nn.OutputLayer)
        }
    }

    // Update weights
    for j := 0; j < len(nn.HiddenLayer); j++ {
        for k := 0; k < len(nn.OutputLayer); k++ {
            delta := nn.LearningRate * outputErrors[k] * nn.HiddenLayer[j]
            nn.WeightsHO[j][k] += delta
        }
    }
    for i := 0; i < len(nn.InputLayer); i++ {
        for j := 0; j < len(nn.HiddenLayer); j++ {
            delta := nn.LearningRate * hiddenErrors[j] * nn.InputLayer[i]
            nn.WeightsIH[i][j] += delta
            if math.IsNaN(delta) {
                fmt.Print(fmt.Sprintf("Delta is NaN.\n Learning Rate: %f\nHidden Errors: %f\nInput: %f\n", nn.LearningRate, hiddenErrors[j], nn.InputLayer[i]))
            }
            if math.IsNaN(nn.WeightsIH[i][j]) {
                fmt.Print(fmt.Sprintf("Delta is NaN.\n Learning Rate: %f\nHidden Errors: %f\nInput: %f\n", nn.LearningRate, hiddenErrors[j], nn.InputLayer[i]))
            }
        }
    }

}
func (nn *RawAI) ExportWeights(filename string) error {
    weightsJson, err := json.Marshal(nn)
    if err != nil {
        return err
    }
    err = ioutil.WriteFile(filename, weightsJson, 0644)
    if err != nil {
        return err
    }
    return nil
}
func (nn *RawAI) ImportWeights(filename string) error {
    weightsJson, err := ioutil.ReadFile(filename)
    if err != nil {
        return err
    }
    err = json.Unmarshal(weightsJson, nn)
    if err != nil {
        return err
    }
    return nil
}

//RawAI Tools:
func randomMatrix(rows, cols int) [][]float64 {
    matrix := make([][]float64, rows)
    for i := 0; i < rows; i++ {
        matrix[i] = make([]float64, cols)
        for j := 0; j < cols; j++ {
            matrix[i][j] = 1.0
        }
    }
    return matrix
}
func sigmoid(x float64) float64 {
    return 1.0 / (1.0 + exp(-x))
}
func sigmoidDerivative(x float64) float64 {
    return x * (1.0 - x)
}

func exp(x float64) float64 {
    return 1.0 + x + (x*x)/2.0 + (x*x*x)/6.0 + (x*x*x*x)/24.0
}
Copy after login

The output example is as follows: As you can see, it slowly moves away from the target and continues to do so. After asking, googling and searching this site I couldn't find where my error was so I decided to ask this question.


Correct answer


I think you are using the mean squared error and forgetting the - after differentiating.

So change:

outputerrors[k] =  (targets[k] - nn.outputlayer[k])
Copy after login

To:

outputErrors[k] = -(targets[k] - nn.OutputLayer[k])
Copy after login

The above is the detailed content of My neural network is trained (from scratch) to get further away from the goal. For more information, please follow other related articles on the PHP Chinese website!

Statement of this Website
The content of this article is voluntarily contributed by netizens, and the copyright belongs to the original author. This site does not assume corresponding legal responsibility. If you find any content suspected of plagiarism or infringement, please contact admin@php.cn

Hot AI Tools

Undresser.AI Undress

Undresser.AI Undress

AI-powered app for creating realistic nude photos

AI Clothes Remover

AI Clothes Remover

Online AI tool for removing clothes from photos.

Undress AI Tool

Undress AI Tool

Undress images for free

Clothoff.io

Clothoff.io

AI clothes remover

AI Hentai Generator

AI Hentai Generator

Generate AI Hentai for free.

Hot Article

R.E.P.O. Energy Crystals Explained and What They Do (Yellow Crystal)
2 weeks ago By 尊渡假赌尊渡假赌尊渡假赌
Repo: How To Revive Teammates
4 weeks ago By 尊渡假赌尊渡假赌尊渡假赌
Hello Kitty Island Adventure: How To Get Giant Seeds
4 weeks ago By 尊渡假赌尊渡假赌尊渡假赌

Hot Tools

Notepad++7.3.1

Notepad++7.3.1

Easy-to-use and free code editor

SublimeText3 Chinese version

SublimeText3 Chinese version

Chinese version, very easy to use

Zend Studio 13.0.1

Zend Studio 13.0.1

Powerful PHP integrated development environment

Dreamweaver CS6

Dreamweaver CS6

Visual web development tools

SublimeText3 Mac version

SublimeText3 Mac version

God-level code editing software (SublimeText3)

Go language pack import: What is the difference between underscore and without underscore? Go language pack import: What is the difference between underscore and without underscore? Mar 03, 2025 pm 05:17 PM

This article explains Go's package import mechanisms: named imports (e.g., import &quot;fmt&quot;) and blank imports (e.g., import _ &quot;fmt&quot;). Named imports make package contents accessible, while blank imports only execute t

How to implement short-term information transfer between pages in the Beego framework? How to implement short-term information transfer between pages in the Beego framework? Mar 03, 2025 pm 05:22 PM

This article explains Beego's NewFlash() function for inter-page data transfer in web applications. It focuses on using NewFlash() to display temporary messages (success, error, warning) between controllers, leveraging the session mechanism. Limita

How to convert MySQL query result List into a custom structure slice in Go language? How to convert MySQL query result List into a custom structure slice in Go language? Mar 03, 2025 pm 05:18 PM

This article details efficient conversion of MySQL query results into Go struct slices. It emphasizes using database/sql's Scan method for optimal performance, avoiding manual parsing. Best practices for struct field mapping using db tags and robus

How do I write mock objects and stubs for testing in Go? How do I write mock objects and stubs for testing in Go? Mar 10, 2025 pm 05:38 PM

This article demonstrates creating mocks and stubs in Go for unit testing. It emphasizes using interfaces, provides examples of mock implementations, and discusses best practices like keeping mocks focused and using assertion libraries. The articl

How can I define custom type constraints for generics in Go? How can I define custom type constraints for generics in Go? Mar 10, 2025 pm 03:20 PM

This article explores Go's custom type constraints for generics. It details how interfaces define minimum type requirements for generic functions, improving type safety and code reusability. The article also discusses limitations and best practices

How to write files in Go language conveniently? How to write files in Go language conveniently? Mar 03, 2025 pm 05:15 PM

This article details efficient file writing in Go, comparing os.WriteFile (suitable for small files) with os.OpenFile and buffered writes (optimal for large files). It emphasizes robust error handling, using defer, and checking for specific errors.

How do you write unit tests in Go? How do you write unit tests in Go? Mar 21, 2025 pm 06:34 PM

The article discusses writing unit tests in Go, covering best practices, mocking techniques, and tools for efficient test management.

How can I use tracing tools to understand the execution flow of my Go applications? How can I use tracing tools to understand the execution flow of my Go applications? Mar 10, 2025 pm 05:36 PM

This article explores using tracing tools to analyze Go application execution flow. It discusses manual and automatic instrumentation techniques, comparing tools like Jaeger, Zipkin, and OpenTelemetry, and highlighting effective data visualization

See all articles