Table des matières
Bonne réponse
Maison développement back-end Golang Mon réseau neuronal est entraîné (à partir de zéro) pour s'éloigner davantage de l'objectif

Mon réseau neuronal est entraîné (à partir de zéro) pour s'éloigner davantage de l'objectif

Feb 06, 2024 am 10:36 AM

Mon réseau neuronal est entraîné (à partir de zéro) pour séloigner davantage de lobjectif

Contenu de la question

C'est la première fois que je crée un réseau de neurones et j'ai décidé de le créer en Golang qui n'est généralement pas le langage prévu à cet effet mais je veux bien comprendre comment ils fonctionnent à partir de zéro uniquement Bibliothèque de base.

Le but de ce programme est de former un réseau de neurones pour pouvoir additionner deux nombres (1-10). Pour ce faire, j'ai créé une classe de réseau neuronal appelée rawai (le meilleur nom auquel je puisse penser) et lui ai donné 1 couche d'entrée (tableau de taille 2), 1 couche cachée (tableau de taille 2) et 1 couche de sortie ( tableau de taille 1).

Les poids ont deux tableaux 2D, l'un est ih (entrée cachée) [2,2] et l'autre est ho, [2,1].

Ce qui suit est le code pour démarrer, entraîner et tester l'IA. Vous verrez plusieurs instructions de débogage que j'ai utilisées et toutes les autres fonctions qui ne sont pas Golang ou ses packages seront affichées dans le code suivant de ma classe rawai. Ceci est appelé par ma fonction principale :

func additionneuralnetworktest() {
    nn := newrawai(2, 2, 1, 1/math.pow(10, 15))
    fmt.printf("weights ih before: %v\n\nweights ho after: %v\n", nn.weightsih, nn.weightsho)
    //train neural network
    //
    for epoch := 0; epoch < 10000000; epoch++ {
        for i := 0; i <= 10; i++ {
            for j := 0; j <= 10; j++ {
                inputs := make([]float64, 2)
                targets := make([]float64, 1)
                inputs[0] = float64(i)
                inputs[1] = float64(j)
                targets[0] = float64(i) + float64(j)
                nn.train(inputs, targets)
                if epoch%20000 == 0 && i == 5 && j == 5 {
                    fmt.printf("[training] [epoch %d] %f + %f = %f targets[%f]\n", epoch, inputs[0], inputs[1], nn.outputlayer[0], targets[0])
                }

            }

        }
    }
    // test neural network
    a := rand.intn(10) + 1
    b := rand.intn(10) + 1
    inputs := make([]float64, 2)
    inputs[0] = float64(a)
    inputs[1] = float64(b)
    prediction := nn.feedforward(inputs)[0]
    fmt.printf("%d + %d = %f\n", a, b, prediction)
    fmt.printf("weights ih: %v\n\nweights ho: %v\n", nn.weightsih, nn.weightsho)

}
Copier après la connexion

Voici tout le code dans le fichier rawai :

type RawAI struct {
    InputLayer   []float64   `json:"input_layer"`
    HiddenLayer  []float64   `json:"hidden_layer"`
    OutputLayer  []float64   `json:"output_layer"`
    WeightsIH    [][]float64 `json:"weights_ih"`
    WeightsHO    [][]float64 `json:"weights_ho"`
    LearningRate float64     `json:"learning_rate"`
}

func NewRawAI(inputSize, hiddenSize, outputSize int, learningRate float64) *RawAI {
    nn := RawAI{
        InputLayer:   make([]float64, inputSize),
        HiddenLayer:  make([]float64, hiddenSize),
        OutputLayer:  make([]float64, outputSize),
        WeightsIH:    randomMatrix(inputSize, hiddenSize),
        WeightsHO:    randomMatrix(hiddenSize, outputSize),
        LearningRate: learningRate,
    }
    return &nn
}
func (nn *RawAI) FeedForward(inputs []float64) []float64 {
    // Set input layer
    for i := 0; i < len(inputs); i++ {
        nn.InputLayer[i] = inputs[i]
    }

    // Compute hidden layer
    for i := 0; i < len(nn.HiddenLayer); i++ {
        sum := 0.0
        for j := 0; j < len(nn.InputLayer); j++ {
            sum += nn.InputLayer[j] * nn.WeightsIH[j][i]
        }
        nn.HiddenLayer[i] = sum
        if math.IsNaN(sum) {
            panic(fmt.Sprintf("Sum is NaN on Hidden Layer:\nInput Layer: %v\nHidden Layer: %v\nWeights IH: %v\n", nn.InputLayer, nn.HiddenLayer, nn.WeightsIH))
        }

    }

    // Compute output layer
    for k := 0; k < len(nn.OutputLayer); k++ {
        sum := 0.0
        for j := 0; j < len(nn.HiddenLayer); j++ {
            sum += nn.HiddenLayer[j] * nn.WeightsHO[j][k]
        }
        nn.OutputLayer[k] = sum
        if math.IsNaN(sum) {
            panic(fmt.Sprintf("Sum is NaN on Output Layer:\n Model: %v\n", nn))
        }

    }

    return nn.OutputLayer
}
func (nn *RawAI) Train(inputs []float64, targets []float64) {
    nn.FeedForward(inputs)

    // Compute output layer error
    outputErrors := make([]float64, len(targets))
    for k := 0; k < len(targets); k++ {
        outputErrors[k] = targets[k] - nn.OutputLayer[k]
    }

    // Compute hidden layer error
    hiddenErrors := make([]float64, len(nn.HiddenLayer))
    for j := 0; j < len(nn.HiddenLayer); j++ {
        errorSum := 0.0
        for k := 0; k < len(nn.OutputLayer); k++ {
            errorSum += outputErrors[k] * nn.WeightsHO[j][k]
        }
        hiddenErrors[j] = errorSum * sigmoidDerivative(nn.HiddenLayer[j])
        if math.IsInf(math.Abs(hiddenErrors[j]), 1) {
            //Find out why
            fmt.Printf("Hidden Error is Infinite:\nTargets:%v\nOutputLayer:%v\n\n", targets, nn.OutputLayer)
        }
    }

    // Update weights
    for j := 0; j < len(nn.HiddenLayer); j++ {
        for k := 0; k < len(nn.OutputLayer); k++ {
            delta := nn.LearningRate * outputErrors[k] * nn.HiddenLayer[j]
            nn.WeightsHO[j][k] += delta
        }
    }
    for i := 0; i < len(nn.InputLayer); i++ {
        for j := 0; j < len(nn.HiddenLayer); j++ {
            delta := nn.LearningRate * hiddenErrors[j] * nn.InputLayer[i]
            nn.WeightsIH[i][j] += delta
            if math.IsNaN(delta) {
                fmt.Print(fmt.Sprintf("Delta is NaN.\n Learning Rate: %f\nHidden Errors: %f\nInput: %f\n", nn.LearningRate, hiddenErrors[j], nn.InputLayer[i]))
            }
            if math.IsNaN(nn.WeightsIH[i][j]) {
                fmt.Print(fmt.Sprintf("Delta is NaN.\n Learning Rate: %f\nHidden Errors: %f\nInput: %f\n", nn.LearningRate, hiddenErrors[j], nn.InputLayer[i]))
            }
        }
    }

}
func (nn *RawAI) ExportWeights(filename string) error {
    weightsJson, err := json.Marshal(nn)
    if err != nil {
        return err
    }
    err = ioutil.WriteFile(filename, weightsJson, 0644)
    if err != nil {
        return err
    }
    return nil
}
func (nn *RawAI) ImportWeights(filename string) error {
    weightsJson, err := ioutil.ReadFile(filename)
    if err != nil {
        return err
    }
    err = json.Unmarshal(weightsJson, nn)
    if err != nil {
        return err
    }
    return nil
}

//RawAI Tools:
func randomMatrix(rows, cols int) [][]float64 {
    matrix := make([][]float64, rows)
    for i := 0; i < rows; i++ {
        matrix[i] = make([]float64, cols)
        for j := 0; j < cols; j++ {
            matrix[i][j] = 1.0
        }
    }
    return matrix
}
func sigmoid(x float64) float64 {
    return 1.0 / (1.0 + exp(-x))
}
func sigmoidDerivative(x float64) float64 {
    return x * (1.0 - x)
}

func exp(x float64) float64 {
    return 1.0 + x + (x*x)/2.0 + (x*x*x)/6.0 + (x*x*x*x)/24.0
}
Copier après la connexion

L'exemple de sortie est le suivant : Comme vous pouvez le constater, il s’éloigne lentement de la cible et continue de s’éloigner. Après avoir demandé, recherché sur Google et recherché ce site, je n'ai pas trouvé où se trouvait mon erreur, j'ai donc décidé de poser cette question.


Bonne réponse


Je pense que vous utilisez 均方误差 并在微分后忘记了 - .

Alors change :

outputerrors[k] =  (targets[k] - nn.outputlayer[k])
Copier après la connexion

À :

outputErrors[k] = -(targets[k] - nn.OutputLayer[k])
Copier après la connexion

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn

Outils d'IA chauds

Undresser.AI Undress

Undresser.AI Undress

Application basée sur l'IA pour créer des photos de nu réalistes

AI Clothes Remover

AI Clothes Remover

Outil d'IA en ligne pour supprimer les vêtements des photos.

Undress AI Tool

Undress AI Tool

Images de déshabillage gratuites

Clothoff.io

Clothoff.io

Dissolvant de vêtements AI

AI Hentai Generator

AI Hentai Generator

Générez AI Hentai gratuitement.

Article chaud

R.E.P.O. Crystals d'énergie expliqués et ce qu'ils font (cristal jaune)
1 Il y a quelques mois By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. Meilleurs paramètres graphiques
1 Il y a quelques mois By 尊渡假赌尊渡假赌尊渡假赌
Will R.E.P.O. Vous avez un jeu croisé?
1 Il y a quelques mois By 尊渡假赌尊渡假赌尊渡假赌

Outils chauds

Bloc-notes++7.3.1

Bloc-notes++7.3.1

Éditeur de code facile à utiliser et gratuit

SublimeText3 version chinoise

SublimeText3 version chinoise

Version chinoise, très simple à utiliser

Envoyer Studio 13.0.1

Envoyer Studio 13.0.1

Puissant environnement de développement intégré PHP

Dreamweaver CS6

Dreamweaver CS6

Outils de développement Web visuel

SublimeText3 version Mac

SublimeText3 version Mac

Logiciel d'édition de code au niveau de Dieu (SublimeText3)

Quelles sont les vulnérabilités de Debian OpenSSL Quelles sont les vulnérabilités de Debian OpenSSL Apr 02, 2025 am 07:30 AM

OpenSSL, en tant que bibliothèque open source largement utilisée dans les communications sécurisées, fournit des algorithmes de chiffrement, des clés et des fonctions de gestion des certificats. Cependant, il existe des vulnérabilités de sécurité connues dans sa version historique, dont certaines sont extrêmement nocives. Cet article se concentrera sur les vulnérabilités et les mesures de réponse communes pour OpenSSL dans Debian Systems. DebianopenSSL CONNUTS Vulnérabilités: OpenSSL a connu plusieurs vulnérabilités graves, telles que: la vulnérabilité des saignements cardiaques (CVE-2014-0160): cette vulnérabilité affecte OpenSSL 1.0.1 à 1.0.1F et 1.0.2 à 1.0.2 Versions bêta. Un attaquant peut utiliser cette vulnérabilité à des informations sensibles en lecture non autorisées sur le serveur, y compris les clés de chiffrement, etc.

Comment utilisez-vous l'outil PPROF pour analyser les performances GO? Comment utilisez-vous l'outil PPROF pour analyser les performances GO? Mar 21, 2025 pm 06:37 PM

L'article explique comment utiliser l'outil PPROF pour analyser les performances GO, notamment l'activation du profilage, la collecte de données et l'identification des goulots d'étranglement communs comme le processeur et les problèmes de mémoire. COMMANDE: 159

Comment rédigez-vous des tests unitaires en Go? Comment rédigez-vous des tests unitaires en Go? Mar 21, 2025 pm 06:34 PM

L'article traite des tests d'unité d'écriture dans GO, couvrant les meilleures pratiques, des techniques de moquerie et des outils pour une gestion efficace des tests.

Quelles bibliothèques sont utilisées pour les opérations du numéro de point flottantes en Go? Quelles bibliothèques sont utilisées pour les opérations du numéro de point flottantes en Go? Apr 02, 2025 pm 02:06 PM

La bibliothèque utilisée pour le fonctionnement du numéro de point flottante dans le langage go présente comment s'assurer que la précision est ...

Quel est le problème avec le fil de file d'attente dans GO's Crawler Colly? Quel est le problème avec le fil de file d'attente dans GO's Crawler Colly? Apr 02, 2025 pm 02:09 PM

Problème de threading de file d'attente dans Go Crawler Colly explore le problème de l'utilisation de la bibliothèque Crawler Crawler dans le langage Go, les développeurs rencontrent souvent des problèmes avec les threads et les files d'attente de demande. � ...

Quelle est la commande Go FMT et pourquoi est-elle importante? Quelle est la commande Go FMT et pourquoi est-elle importante? Mar 20, 2025 pm 04:21 PM

L'article traite de la commande GO FMT dans GO Programming, qui formate le code pour adhérer aux directives de style officiel. Il met en évidence l'importance de GO FMT pour maintenir la cohérence du code, la lisibilité et la réduction des débats de style. Meilleures pratiques pour

Transformant du développement frontal au développement back-end, est-il plus prometteur d'apprendre Java ou Golang? Transformant du développement frontal au développement back-end, est-il plus prometteur d'apprendre Java ou Golang? Apr 02, 2025 am 09:12 AM

Chemin d'apprentissage du backend: le parcours d'exploration du front-end à l'arrière-end en tant que débutant back-end qui se transforme du développement frontal, vous avez déjà la base de Nodejs, ...

Méthode de surveillance postgresql sous Debian Méthode de surveillance postgresql sous Debian Apr 02, 2025 am 07:27 AM

Cet article présente une variété de méthodes et d'outils pour surveiller les bases de données PostgreSQL sous le système Debian, vous aidant à saisir pleinement la surveillance des performances de la base de données. 1. Utilisez PostgreSQL pour reprendre la surveillance Afficher PostgreSQL lui-même offre plusieurs vues pour surveiller les activités de la base de données: PG_STAT_ACTIVITY: affiche les activités de la base de données en temps réel, y compris les connexions, les requêtes, les transactions et autres informations. PG_STAT_REPLIcation: surveille l'état de réplication, en particulier adapté aux grappes de réplication de flux. PG_STAT_DATABASE: Fournit des statistiques de base de données, telles que la taille de la base de données, les temps de validation / recul des transactions et d'autres indicateurs clés. 2. Utilisez l'outil d'analyse de journaux pgbadg

See all articles