Meta FAIR digabungkan dengan Harvard untuk menyediakan rangka kerja penyelidikan baharu untuk mengoptimumkan bias data yang disebabkan oleh pembelajaran mesin berskala besar.
Seperti yang kita sedia maklum, latihan model bahasa besar selalunya mengambil masa berbulan-bulan dan menggunakan ratusan malah beribu-ribu GPU. Mengambil model LLaMA2 70B sebagai contoh, latihannya memerlukan sejumlah 1,720,320 jam GPU. Melatih model besar memberikan cabaran sistemik yang unik disebabkan oleh skala dan kerumitan beban kerja ini.
Baru-baru ini, banyak institusi telah melaporkan ketidakstabilan semasa proses latihan semasa melatih model AI generatif SOTA biasanya muncul dalam bentuk lonjakan kerugian, seperti model PaLM Google, yang berlaku sehingga 20 kali semasa proses latihan pancang.
Sisihan berangka adalah punca ketidaktepatan latihan ini Disebabkan oleh kos pelaksanaan latihan model bahasa yang sangat tinggi, cara mengira sisihan berangka telah menjadi isu utama.
Dalam karya terbaru, penyelidik dari Meta dan Universiti Harvard membangunkan kaedah kuantitatif berprinsip untuk memahami bias berangka dalam pengoptimuman latihan. Ini digunakan untuk menilai teknik pengoptimuman tercanggih yang berbeza dan menentukan sama ada teknik tersebut mungkin menimbulkan ketidakstabilan yang tidak dijangka apabila digunakan untuk melatih model besar. Para penyelidik mendapati bahawa walaupun kaedah pengoptimuman sedia ada berfungsi dengan baik pada beberapa tugas, beberapa sisihan berangka berlaku apabila digunakan pada model besar. Kecondongan berangka ini boleh mewujudkan ketidakstabilan semasa proses latihan, menyebabkan prestasi model merosot. Untuk menyelesaikan masalah ini, penyelidik mencadangkan pengoptimuman berdasarkan kaedah kuantitatif berprinsip
Didapati bahawa dalam satu hantaran hadapan, sisihan berangka Perhatian Kilat adalah susunan magnitud yang lebih besar daripada Perhatian Garis Dasar BF16.
Secara khusus, kaedah ini terdiri daripada dua peringkat, termasuk:
Para penyelidik menganalisis teknologi pengoptimuman SOTA Flash Attention dan mengukur sisihan berangka yang mungkin diperkenalkan. Flash Attention ialah teknologi yang digunakan secara meluas untuk mempercepatkan mekanisme perhatian dan sering dianggap sebagai kesesakan sistem dalam model Transformer. Walaupun Flash Attention meningkatkan kelajuan dan mengurangkan akses memori, ia juga bergantung pada pengoptimuman algoritma, dan pengoptimuman algoritma boleh membawa kepada peningkatan sisihan berangka.
Para penyelidik membuat hipotesis bahawa menambahkan faktor penskalaan semula mungkin memperkenalkan anggaran yang tidak disengajakan, yang membawa kepada pertukaran berangka, yang kemudiannya boleh menjejaskan kestabilan latihan.
Mereka menganalisis Flash Attention dalam konteks beban kerja teks-ke-imej berbilang mod untuk menentukan potensi kepentingan sisihan berangka antara Flash Attention dan garis dasarnya. Akhirnya, mereka memperkenalkan rangka kerja untuk mengukur bias berangka pengoptimuman latihan dan kesan hilirannya.
Penyelidik telah membuat dua sumbangan utama berikut dalam mengukur sisihan berangka:
Penanda aras mikro yang direka oleh penyelidik ialah teknik yang digunakan untuk mengukur dan mengukur sisihan berangka yang disebabkan oleh pengoptimuman kotak hitam tradisional (seperti Flash Attention). Dengan mengganggu aspek yang biasanya tidak tersedia dalam kernel yang disediakan, mereka mempelopori penemuan bahawa pada ketepatan berangka yang rendah (BF16), Flash Attention mempunyai lebih kurang susunan magnitud bias berangka yang lebih tinggi berbanding dengan Baseline Attention.
Melalui analisis ini, penyelidik mengkontekstualisasikan sisihan berangka yang diperhatikan dan membentuk had atas kesannya terhadap sifat model hiliran. Dalam kajian kes penyelidik, mereka dapat mengehadkan kesan bias berangka yang diperhatikan dan mendapati: "Flash Attention memperkenalkan bias berat model yang kira-kira 1/2 hingga 1/5 kali ganda daripada latihan ketepatan rendah
Kajian ini menyerlahkan kepentingan membangunkan pendekatan berprinsip untuk "bukan sahaja mengukur tetapi juga mengkontekstualisasikan kesan pengoptimuman latihan ke atas bias berangka." Dengan membina proksi untuk mengkontekstualisasikan konteks berat sebelah berangka, bertujuan untuk membuat kesimpulan kemungkinan kesan model hiliran (iaitu. , ketidakstabilan latihan) yang selalunya sukar diukur.
Para penyelidik mula-mula membangunkan penanda aras mikro untuk mengasingkan dan mengkaji sisihan berangka yang disebabkan oleh Flash Attention. Seperti yang ditunjukkan dalam Rajah 2, mereka melaksanakan semula Flash Attention secara berangka untuk menganalisis ketepatan berangka yang berbeza dan menggunakan langkah pengoptimuman yang berpotensi pada setiap langkah algoritma.
Rajah 2: Ringkasan reka bentuk penanda aras mikro.
Ini perlu kerana teras Flash Attention pada masa ini hanya menyokong format berangka FP16 dan BF16. Kernel ini juga merupakan panggilan API pembalut untuk kod CUDA, yang menjadikannya mencabar untuk mengganggu algoritma untuk mengkaji kesan bias berangka.
Sebaliknya, reka bentuk penanda aras mikro mereka membenarkan input ketepatan dan pengubahsuaian dalam algoritma. Para penyelidik mengesahkan penanda aras mikro terhadap kernel Flash Attention yang asal.
Mereka seterusnya mereka bentuk teknik untuk membandingkan output matriks Perhatian pada setiap langkah semasa pelaksanaan model. Dan mengubah suai kod model untuk mengira Perhatian Garis Dasar dan Perhatian Kilat setiap kali perhatian dipanggil, yang membolehkan perbandingan matriks keluaran yang tepat untuk matriks input yang sama.
Untuk meletakkan ini dalam konteks, kami juga menggunakan perbezaan Max dan metrik Jarak Wasserstein untuk mengukur perbezaan berat model sepanjang latihan menggunakan larian latihan yang sama dan bebas.
Untuk percubaan latihan, penyelidik menggunakan beban kerja AI generatif (iaitu model teks-ke-imej) yang menukar input teks kepada imej. Mereka melatih semula model menggunakan set data Shutterstock dan menjalankan eksperimen pada kelompok GPU NVIDIA 80GB A100.
Para penyelidik mula-mula menganalisis kesan Perhatian Kilat dalam proses hantaran hadapan. Mereka menggunakan penanda aras mikro untuk mengkaji kesan ketepatan berangka yang berbeza pada matriks keluaran yang dikira oleh Perhatian, di bawah syarat bahawa pertanyaan, kunci dan vektor nilai yang dimulakan secara rawak adalah sama.
Seperti yang ditunjukkan dalam Rajah 3, apabila penyelidik menggunakan format berangka yang berbeza antara BF16 hingga FP64, sisihan berangka antara Flash Attention dan Baseline Attention berkurangan apabila bilangan digit mantissa meningkat. Ini menunjukkan bahawa perbezaan berangka adalah disebabkan oleh anggaran yang wujud dalam mempunyai angka mantissa yang lebih sedikit.
Rajah 3: Kesan format berangka pada sisihan berangka Flash Attention.
Selepas itu, penyelidik menetapkan "nilai emas" untuk Perhatian Baseline dalam format berangka FP64 untuk perbandingan standard, dan kemudian membandingkan output Perhatian dalam format berangka yang berbeza dengan nilai ini (seperti yang ditunjukkan dalam Rajah 4).
Rajah 4: Perbandingan "nilai emas" Perhatian Baseline di bawah FP64.
Keputusan menunjukkan bahawa sisihan berangka Perhatian Kilat adalah kira-kira 10 kali ganda daripada Garis Dasar di bawah BF16.
Untuk menganalisis selanjutnya sisihan berangka yang diperhatikan ini, penyelidik mengimbas panjang jujukan matriks sambil mengekalkan saiz jubin dan saiz SRAM yang tetap (seperti yang ditunjukkan dalam Rajah 5).
Rajah 5: Kesan panjang jujukan pada sisihan berangka Perhatian Denyar.
Seperti yang ditunjukkan dalam rajah, apabila panjang jujukan bertambah, sama ada diukur dengan (a) had atas perbezaan maksimum, atau dengan (b) min dan sisihan piawai perbezaan, perbezaan antara Perhatian Denyar dan Garis Dasar Perhatian Sisihan berangka semakin meningkat.
Selain itu, penyelidik juga menggunakan reka bentuk penanda aras mikro untuk menjalankan eksperimen dengan pengoptimuman berbeza untuk lebih memahami kesan sisihan berangka (seperti ditunjukkan dalam Rajah 6).
Rajah 6a menunjukkan cara menukar susunan dimensi blok menyebabkan perbezaan berangka antara Perhatian Kilat dan Perhatian Garis Dasar meningkat. Gangguan lain dalam Rajah 6b, seperti mengehadkan saiz jubin kepada segi empat sama, tidak mempunyai kesan ke atas bias berangka. Rajah 6c menunjukkan bahawa lebih besar saiz blok/jubin, lebih kecil sisihan berangka.
Rajah 6: Perubahan algoritma dan kesannya terhadap sisihan berangka yang diperhatikan.
Walaupun Flash Attention boleh menyebabkan bias berangka dalam output Attention semasa hantaran hadapan, matlamat utama kajian ini adalah untuk menentukan sama ada ini akan berlaku semasa latihan model menghasilkan sebarang kesan untuk menyiasat sama ada ia menyumbang kepada ketidakstabilan latihan.
Oleh itu, penyelidik berharap dapat mengukur sama ada Flash Attention menukar model semasa latihan, iaitu sama ada perbezaan dalam output Perhatian yang diperhatikan di atas ditunjukkan dalam berat model yang dikemas kini semasa latihan.
Para penyelidik menggunakan dua penunjuk untuk mengukur perbezaan berat model antara model yang dilatih menggunakan Baseline Attention dan model yang dilatih menggunakan Flash Attention. Perbezaan maksimum pertama kali dikira, iaitu, mencari nilai mutlak perbezaan antara matriks berat dan mengambil nilai maksimum, dengan itu memperoleh had atas sisihan, seperti berikut:
Manakala perbezaan maksimum memberikan had atas sisihan berangka, Tetapi ia tidak mengambil kira taburan setiap matriks. Oleh itu, penyelidik mengira perbezaan berat melalui Jarak Wasserstein, yang merupakan ukuran persamaan yang biasa antara tensor. Walaupun sedikit lebih kompleks dari segi pengiraan, Jarak Wasserstein termasuk maklumat bentuk taburan tensor untuk mengukur persamaan. Formula pengiraan diringkaskan seperti berikut:
Semakin rendah nilai, semakin tinggi persamaan antara matriks.
Menggunakan kedua-dua metrik ini, penyelidik kemudian mengukur bagaimana berat model Perhatian Kilat berubah berbanding Perhatian Baseline sepanjang proses latihan:
Mengikut penunjuk Wasserstein, Untuk Dua Perbezaan Maksimum ini keseluruhan proses latihan, penambahan Flash Attention memang mengubah berat model, dan apabila latihan diteruskan, perbezaan ini hanya akan menjadi lebih besar dan lebih besar Ini menunjukkan bahawa model yang dilatih menggunakan Flash Attention adalah berbeza daripada model yang dilatih menggunakan Baseline Attention. Model yang sama dilatih menumpu kepada model yang berbeza.
Walau bagaimanapun, latihan ialah proses stokastik, dan perubahan tertentu dalam struktur model mungkin menghasilkan keputusan yang sama dari segi kesan hiliran dan ketepatan. Ini patut diberi perhatian walaupun berat model yang dilatih dengan Flash Attention dan Baseline Attention adalah berbeza.
Melatih model sepenuhnya dan menilai ketepatan adalah tugas yang mahal dan memerlukan sumber, terutamanya untuk model besar yang mengambil masa berbulan-bulan untuk dilatih.
Penyelidik mengkonfigurasi proksi untuk meneroka:
(a) Sejauh manakah perubahan berat badan ini ketara?
(b) Bolehkah ini dikaitkan dengan perubahan berat standard dalam pengoptimuman latihan lain yang diterima pakai secara meluas?
Untuk mencapai matlamat ini, penyelidik mereka satu siri eksperimen untuk membandingkan bagaimana perbezaan berat badan berubah semasa proses latihan di bawah senario yang berbeza.
Selain membandingkan proses latihan menggunakan Flash Attention dan Baseline Attention, mereka juga mengukur perbezaan berat semasa proses latihan yang sama di mana pemberat dimulakan kepada nilai rawak yang berbeza pada permulaan latihan. Ini memberikan batasan, kerana permulaan berat rawak adalah teknik biasa dan selalunya menghasilkan hasil yang setara.
Selain itu, penyelidik juga mengukur perubahan dalam berat model yang dilatih dengan ketepatan yang berbeza. Ketepatan berangka (iaitu, FP16 lwn. FP32) berpotensi menyebabkan perubahan hiliran, yang berfungsi sebagai sempadan atas kepentingan pemberat Perhatian Flash.
Seperti yang ditunjukkan dalam Rajah 8, boleh didapati bahawa kadar perubahan berat sebelah berat model menggunakan Flash Attention adalah setanding atau lebih kecil daripada kadar perubahan bias berat bagi permulaan model yang berbeza (perhatikan cerun lengkung merah dan biru) .
Selain itu, kadar perubahan berat apabila menggunakan FP16 vs. FP32 adalah lebih tinggi dan perubahan lebih besar daripada apabila model berbeza dimulakan.
Keputusan ini memberikan proksi dan menunjukkan: "Walaupun Perhatian Flash akan mengalami bias berangka, ia akan dihadkan oleh permulaan model rawak dan latihan ketepatan rendah. Dan berat badan model yang diperkenalkan adalah kira-kira 10% apabila latihan dengan rendah ketepatan. 1/2 hingga 1/5 kali ganda.
Untuk butiran penyelidikan lanjut, sila rujuk kertas asal.
Atas ialah kandungan terperinci Adakah Flash Attention stabil? Meta dan Harvard mendapati bahawa sisihan berat model mereka berubah-ubah mengikut urutan magnitud. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!