Jadual Kandungan
Jawapan betul
Rumah pembangunan bahagian belakang Golang Rangkaian saraf saya dilatih (dari awal) untuk pergi lebih jauh daripada matlamat

Rangkaian saraf saya dilatih (dari awal) untuk pergi lebih jauh daripada matlamat

Feb 06, 2024 am 10:36 AM

Rangkaian saraf saya dilatih (dari awal) untuk pergi lebih jauh daripada matlamat

Kandungan soalan

Ini adalah kali pertama saya mencipta rangkaian saraf dan saya memutuskan untuk menciptanya dalam golang yang biasanya bukan bahasa untuk tujuan ini tetapi saya ingin memahami dengan baik bagaimana ia berfungsi dari awal sahaja Asasperpustakaan.

Matlamat program ini adalah untuk melatih rangkaian saraf supaya dapat menambah dua nombor (1-10). Untuk melakukan ini, saya mencipta kelas rangkaian saraf yang dipanggil rawai (nama terbaik yang boleh saya fikirkan) dan memberikannya 1 lapisan input (tatasusunan saiz 2), 1 lapisan tersembunyi (tatasusunan saiz 2) dan 1 lapisan output ( tatasusunan saiz 1).

Berat mempunyai dua tatasusunan 2d, satu ialah ih (input tersembunyi) [2,2], dan satu lagi ialah ho, [2,1].

Berikut ialah kod untuk memulakan, melatih dan menguji ai. Anda akan melihat beberapa pernyataan penyahpepijatan yang telah saya gunakan dan sebarang fungsi lain yang bukan golang atau pakejnya akan ditunjukkan dalam kod kelas rawai saya yang berikut. Ini dipanggil oleh fungsi utama saya:

func additionneuralnetworktest() {
    nn := newrawai(2, 2, 1, 1/math.pow(10, 15))
    fmt.printf("weights ih before: %v\n\nweights ho after: %v\n", nn.weightsih, nn.weightsho)
    //train neural network
    //
    for epoch := 0; epoch < 10000000; epoch++ {
        for i := 0; i <= 10; i++ {
            for j := 0; j <= 10; j++ {
                inputs := make([]float64, 2)
                targets := make([]float64, 1)
                inputs[0] = float64(i)
                inputs[1] = float64(j)
                targets[0] = float64(i) + float64(j)
                nn.train(inputs, targets)
                if epoch%20000 == 0 && i == 5 && j == 5 {
                    fmt.printf("[training] [epoch %d] %f + %f = %f targets[%f]\n", epoch, inputs[0], inputs[1], nn.outputlayer[0], targets[0])
                }

            }

        }
    }
    // test neural network
    a := rand.intn(10) + 1
    b := rand.intn(10) + 1
    inputs := make([]float64, 2)
    inputs[0] = float64(a)
    inputs[1] = float64(b)
    prediction := nn.feedforward(inputs)[0]
    fmt.printf("%d + %d = %f\n", a, b, prediction)
    fmt.printf("weights ih: %v\n\nweights ho: %v\n", nn.weightsih, nn.weightsho)

}
Salin selepas log masuk

Ini semua kod dalam fail rawai:

type RawAI struct {
    InputLayer   []float64   `json:"input_layer"`
    HiddenLayer  []float64   `json:"hidden_layer"`
    OutputLayer  []float64   `json:"output_layer"`
    WeightsIH    [][]float64 `json:"weights_ih"`
    WeightsHO    [][]float64 `json:"weights_ho"`
    LearningRate float64     `json:"learning_rate"`
}

func NewRawAI(inputSize, hiddenSize, outputSize int, learningRate float64) *RawAI {
    nn := RawAI{
        InputLayer:   make([]float64, inputSize),
        HiddenLayer:  make([]float64, hiddenSize),
        OutputLayer:  make([]float64, outputSize),
        WeightsIH:    randomMatrix(inputSize, hiddenSize),
        WeightsHO:    randomMatrix(hiddenSize, outputSize),
        LearningRate: learningRate,
    }
    return &nn
}
func (nn *RawAI) FeedForward(inputs []float64) []float64 {
    // Set input layer
    for i := 0; i < len(inputs); i++ {
        nn.InputLayer[i] = inputs[i]
    }

    // Compute hidden layer
    for i := 0; i < len(nn.HiddenLayer); i++ {
        sum := 0.0
        for j := 0; j < len(nn.InputLayer); j++ {
            sum += nn.InputLayer[j] * nn.WeightsIH[j][i]
        }
        nn.HiddenLayer[i] = sum
        if math.IsNaN(sum) {
            panic(fmt.Sprintf("Sum is NaN on Hidden Layer:\nInput Layer: %v\nHidden Layer: %v\nWeights IH: %v\n", nn.InputLayer, nn.HiddenLayer, nn.WeightsIH))
        }

    }

    // Compute output layer
    for k := 0; k < len(nn.OutputLayer); k++ {
        sum := 0.0
        for j := 0; j < len(nn.HiddenLayer); j++ {
            sum += nn.HiddenLayer[j] * nn.WeightsHO[j][k]
        }
        nn.OutputLayer[k] = sum
        if math.IsNaN(sum) {
            panic(fmt.Sprintf("Sum is NaN on Output Layer:\n Model: %v\n", nn))
        }

    }

    return nn.OutputLayer
}
func (nn *RawAI) Train(inputs []float64, targets []float64) {
    nn.FeedForward(inputs)

    // Compute output layer error
    outputErrors := make([]float64, len(targets))
    for k := 0; k < len(targets); k++ {
        outputErrors[k] = targets[k] - nn.OutputLayer[k]
    }

    // Compute hidden layer error
    hiddenErrors := make([]float64, len(nn.HiddenLayer))
    for j := 0; j < len(nn.HiddenLayer); j++ {
        errorSum := 0.0
        for k := 0; k < len(nn.OutputLayer); k++ {
            errorSum += outputErrors[k] * nn.WeightsHO[j][k]
        }
        hiddenErrors[j] = errorSum * sigmoidDerivative(nn.HiddenLayer[j])
        if math.IsInf(math.Abs(hiddenErrors[j]), 1) {
            //Find out why
            fmt.Printf("Hidden Error is Infinite:\nTargets:%v\nOutputLayer:%v\n\n", targets, nn.OutputLayer)
        }
    }

    // Update weights
    for j := 0; j < len(nn.HiddenLayer); j++ {
        for k := 0; k < len(nn.OutputLayer); k++ {
            delta := nn.LearningRate * outputErrors[k] * nn.HiddenLayer[j]
            nn.WeightsHO[j][k] += delta
        }
    }
    for i := 0; i < len(nn.InputLayer); i++ {
        for j := 0; j < len(nn.HiddenLayer); j++ {
            delta := nn.LearningRate * hiddenErrors[j] * nn.InputLayer[i]
            nn.WeightsIH[i][j] += delta
            if math.IsNaN(delta) {
                fmt.Print(fmt.Sprintf("Delta is NaN.\n Learning Rate: %f\nHidden Errors: %f\nInput: %f\n", nn.LearningRate, hiddenErrors[j], nn.InputLayer[i]))
            }
            if math.IsNaN(nn.WeightsIH[i][j]) {
                fmt.Print(fmt.Sprintf("Delta is NaN.\n Learning Rate: %f\nHidden Errors: %f\nInput: %f\n", nn.LearningRate, hiddenErrors[j], nn.InputLayer[i]))
            }
        }
    }

}
func (nn *RawAI) ExportWeights(filename string) error {
    weightsJson, err := json.Marshal(nn)
    if err != nil {
        return err
    }
    err = ioutil.WriteFile(filename, weightsJson, 0644)
    if err != nil {
        return err
    }
    return nil
}
func (nn *RawAI) ImportWeights(filename string) error {
    weightsJson, err := ioutil.ReadFile(filename)
    if err != nil {
        return err
    }
    err = json.Unmarshal(weightsJson, nn)
    if err != nil {
        return err
    }
    return nil
}

//RawAI Tools:
func randomMatrix(rows, cols int) [][]float64 {
    matrix := make([][]float64, rows)
    for i := 0; i < rows; i++ {
        matrix[i] = make([]float64, cols)
        for j := 0; j < cols; j++ {
            matrix[i][j] = 1.0
        }
    }
    return matrix
}
func sigmoid(x float64) float64 {
    return 1.0 / (1.0 + exp(-x))
}
func sigmoidDerivative(x float64) float64 {
    return x * (1.0 - x)
}

func exp(x float64) float64 {
    return 1.0 + x + (x*x)/2.0 + (x*x*x)/6.0 + (x*x*x*x)/24.0
}
Salin selepas log masuk

Contoh output adalah seperti berikut: Seperti yang anda lihat, ia perlahan-lahan bergerak dari sasaran dan terus melakukannya. Selepas bertanya, googling dan mencari tapak ini saya tidak dapat mencari di mana ralat saya jadi saya memutuskan untuk bertanya soalan ini.


Jawapan betul


Saya rasa anda sedang menggunakan 均方误差 并在微分后忘记了 - .

Jadi tukar:

outputerrors[k] =  (targets[k] - nn.outputlayer[k])
Salin selepas log masuk

Kepada:

outputErrors[k] = -(targets[k] - nn.OutputLayer[k])
Salin selepas log masuk

Atas ialah kandungan terperinci Rangkaian saraf saya dilatih (dari awal) untuk pergi lebih jauh daripada matlamat. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn

Alat AI Hot

Undresser.AI Undress

Undresser.AI Undress

Apl berkuasa AI untuk mencipta foto bogel yang realistik

AI Clothes Remover

AI Clothes Remover

Alat AI dalam talian untuk mengeluarkan pakaian daripada foto.

Undress AI Tool

Undress AI Tool

Gambar buka pakaian secara percuma

Clothoff.io

Clothoff.io

Penyingkiran pakaian AI

AI Hentai Generator

AI Hentai Generator

Menjana ai hentai secara percuma.

Artikel Panas

R.E.P.O. Kristal tenaga dijelaskan dan apa yang mereka lakukan (kristal kuning)
1 bulan yang lalu By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. Tetapan grafik terbaik
1 bulan yang lalu By 尊渡假赌尊渡假赌尊渡假赌
Akan R.E.P.O. Ada Crossplay?
1 bulan yang lalu By 尊渡假赌尊渡假赌尊渡假赌

Alat panas

Notepad++7.3.1

Notepad++7.3.1

Editor kod yang mudah digunakan dan percuma

SublimeText3 versi Cina

SublimeText3 versi Cina

Versi Cina, sangat mudah digunakan

Hantar Studio 13.0.1

Hantar Studio 13.0.1

Persekitaran pembangunan bersepadu PHP yang berkuasa

Dreamweaver CS6

Dreamweaver CS6

Alat pembangunan web visual

SublimeText3 versi Mac

SublimeText3 versi Mac

Perisian penyuntingan kod peringkat Tuhan (SublimeText3)

Apakah kelemahan debian openssl Apakah kelemahan debian openssl Apr 02, 2025 am 07:30 AM

OpenSSL, sebagai perpustakaan sumber terbuka yang digunakan secara meluas dalam komunikasi yang selamat, menyediakan algoritma penyulitan, kunci dan fungsi pengurusan sijil. Walau bagaimanapun, terdapat beberapa kelemahan keselamatan yang diketahui dalam versi sejarahnya, yang sebahagiannya sangat berbahaya. Artikel ini akan memberi tumpuan kepada kelemahan umum dan langkah -langkah tindak balas untuk OpenSSL dalam sistem Debian. Debianopenssl yang dikenal pasti: OpenSSL telah mengalami beberapa kelemahan yang serius, seperti: Kerentanan Pendarahan Jantung (CVE-2014-0160): Kelemahan ini mempengaruhi OpenSSL 1.0.1 hingga 1.0.1f dan 1.0.2 hingga 1.0.2 versi beta. Penyerang boleh menggunakan kelemahan ini untuk maklumat sensitif baca yang tidak dibenarkan di pelayan, termasuk kunci penyulitan, dll.

Bagaimana anda menggunakan alat PPROF untuk menganalisis prestasi GO? Bagaimana anda menggunakan alat PPROF untuk menganalisis prestasi GO? Mar 21, 2025 pm 06:37 PM

Artikel ini menerangkan cara menggunakan alat PPROF untuk menganalisis prestasi GO, termasuk membolehkan profil, mengumpul data, dan mengenal pasti kesesakan biasa seperti CPU dan isu memori.

Bagaimana anda menulis ujian unit di GO? Bagaimana anda menulis ujian unit di GO? Mar 21, 2025 pm 06:34 PM

Artikel ini membincangkan ujian unit menulis di GO, meliputi amalan terbaik, teknik mengejek, dan alat untuk pengurusan ujian yang cekap.

Perpustakaan apa yang digunakan untuk operasi nombor terapung di GO? Perpustakaan apa yang digunakan untuk operasi nombor terapung di GO? Apr 02, 2025 pm 02:06 PM

Perpustakaan yang digunakan untuk operasi nombor terapung dalam bahasa Go memperkenalkan cara memastikan ketepatannya ...

Apakah masalah dengan thread giliran di crawler colly go? Apakah masalah dengan thread giliran di crawler colly go? Apr 02, 2025 pm 02:09 PM

Masalah Threading Giliran di GO Crawler Colly meneroka masalah menggunakan Perpustakaan Colly Crawler dalam bahasa Go, pemaju sering menghadapi masalah dengan benang dan permintaan beratur. � ...

Apakah arahan Go FMT dan mengapa ia penting? Apakah arahan Go FMT dan mengapa ia penting? Mar 20, 2025 pm 04:21 PM

Artikel ini membincangkan perintah Go FMT dalam pengaturcaraan GO, yang format kod untuk mematuhi garis panduan gaya rasmi. Ia menyoroti kepentingan GO FMT untuk mengekalkan konsistensi kod, kebolehbacaan, dan mengurangkan perdebatan gaya. Amalan terbaik untuk

Berubah dari front-end ke pembangunan back-end, adakah lebih menjanjikan untuk belajar Java atau Golang? Berubah dari front-end ke pembangunan back-end, adakah lebih menjanjikan untuk belajar Java atau Golang? Apr 02, 2025 am 09:12 AM

Laluan Pembelajaran Backend: Perjalanan Eksplorasi dari Front-End ke Back-End sebagai pemula back-end yang berubah dari pembangunan front-end, anda sudah mempunyai asas Nodejs, ...

Kaedah Pemantauan PostgreSQL di bawah Debian Kaedah Pemantauan PostgreSQL di bawah Debian Apr 02, 2025 am 07:27 AM

Artikel ini memperkenalkan pelbagai kaedah dan alat untuk memantau pangkalan data PostgreSQL di bawah sistem Debian, membantu anda memahami pemantauan prestasi pangkalan data sepenuhnya. 1. Gunakan PostgreSQL untuk membina pemantauan PostgreSQL sendiri menyediakan pelbagai pandangan untuk pemantauan aktiviti pangkalan data: PG_STAT_ACTIVITY: Memaparkan aktiviti pangkalan data dalam masa nyata, termasuk sambungan, pertanyaan, urus niaga dan maklumat lain. PG_STAT_REPLITI: Memantau status replikasi, terutamanya sesuai untuk kluster replikasi aliran. PG_STAT_DATABASE: Menyediakan statistik pangkalan data, seperti saiz pangkalan data, masa komitmen/masa rollback transaksi dan petunjuk utama lain. 2. Gunakan alat analisis log pgbadg

See all articles