Maison > développement back-end > Tutoriel Python > Comment implémenter la rétropropagation softmax en Python.

Comment implémenter la rétropropagation softmax en Python.

WBOY
Libérer: 2023-05-09 08:05:53
avant
1256 Les gens l'ont consulté

Dérivation de rétropropagation

Comme vous pouvez le voir, softmax calcule les entrées de plusieurs neurones. Lors de la dérivation de rétropropagation, vous devez envisager de dériver les paramètres de différents neurones.

Considérons deux situations :

  • Lorsque le paramètre de différenciation est situé au numérateur

  • Lorsque le paramètre de différenciation est situé au dénominateur

Comment implémenter la rétropropagation softmax en Python.

Lorsque le paramètre de différenciation est situé au numérateur :

Comment implémenter la rétropropagation softmax en Python.

Lorsque le paramètre de dérivation est situé au dénominateur (ez2 ou ez3 sont symétriques, le résultat de la dérivation est le même) :

Comment implémenter la rétropropagation softmax en Python.

Comment implémenter la rétropropagation softmax en Python.

Code

import torch
import math

def my_softmax(features):
    _sum = 0
    for i in features:
        _sum += math.e ** i
    return torch.Tensor([ math.e ** i / _sum for i in features ])

def my_softmax_grad(outputs):    
    n = len(outputs)
    grad = []
    for i in range(n):
        temp = []
        for j in range(n):
            if i == j:
                temp.append(outputs[i] * (1- outputs[i]))
            else:
                temp.append(-outputs[j] * outputs[i])
        grad.append(torch.Tensor(temp))
    return grad

if __name__ == '__main__':

    features = torch.randn(10)
    features.requires_grad_()

    torch_softmax = torch.nn.functional.softmax
    p1 = torch_softmax(features,dim=0)
    p2 = my_softmax(features)
    print(torch.allclose(p1,p2))
    
    n = len(p1)
    p2_grad = my_softmax_grad(p2)
    for i in range(n):
        p1_grad = torch.autograd.grad(p1[i],features, retain_graph=True)
        print(torch.allclose(p1_grad[0], p2_grad[i]))
Copier après la connexion

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:yisu.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal