Dari awal tahun hingga sekarang, AI generatif telah berkembang pesat. Tetapi banyak kali, kita perlu menghadapi masalah yang sukar: bagaimana untuk mempercepatkan latihan, penaakulan, dll. AI generatif, terutamanya apabila menggunakan PyTorch.
Dalam artikel ini, penyelidik dari pasukan PyTorch memberikan kami penyelesaian. Artikel ini memfokuskan pada cara menggunakan PyTorch asli tulen untuk mempercepatkan model AI generatif Ia juga memperkenalkan ciri PyTorch baharu dan contoh praktikal tentang cara menggabungkannya.
Apakah hasilnya? Pasukan PyTorch berkata mereka menulis semula model "Split Everything" (SAM) Meta, menghasilkan kod yang 8 kali lebih pantas daripada pelaksanaan asal tanpa kehilangan ketepatan, semuanya dioptimumkan menggunakan PyTorch asli.
Alamat blog: https://pytorch.org/blog/accelerating-generative-ai/
Selepas membaca artikel ini, anda akan mendapat pemahaman berikut:
.TorchKandungan akan mendalam lapisan demi lapisan . Pada akhir artikel ini, kami akan memperkenalkan SAM versi pantas. Untuk pembaca yang berminat, anda boleh memuat turunnya dari GitHub. Di samping itu, data ini divisualisasikan menggunakan Perfetto UI untuk menunjukkan nilai aplikasi pelbagai ciri PyTorch
Projek ini boleh didapati di alamat GitHub: https://github.com/pytorch-labs/segment-anything-fast Kod sumber
menulis semula split everything model SAM
Kajian menunjukkan bahawa jenis data garis dasar SAM yang digunakan dalam artikel ini ialah float32 dtype, saiz batch ialah 1, dan PyTorch Profiler digunakan untuk melihat hasil penjejakan teras Seperti berikut:
Yang pertama ialah panggilan panjang ke aten::index, yang disebabkan oleh operasi indeks tensor (seperti []) Disebabkan oleh panggilan asas yang dijana. Walau bagaimanapun, masa sebenar yang dibelanjakan oleh GPU untuk aten::index adalah agak rendah Sebabnya ialah semasa proses memulakan dua teras, aten::index menyekat cudaStreamSynchronize antara mereka. Ini bermakna CPU menunggu sehingga GPU selesai diproses sehingga teras kedua dilancarkan. Oleh itu, untuk mengoptimumkan SAM, kertas kerja ini percaya bahawa seseorang harus berusaha untuk menghapuskan penyegerakan GPU yang menyekat yang menyebabkan masa terbiar.
Masalah kedua ialah SAM menghabiskan banyak masa GPU dalam pendaraban matriks (bahagian hijau gelap seperti yang ditunjukkan dalam gambar), yang sangat biasa dalam model Transformers. Jika kita boleh mengurangkan masa GPU model SAM pada pendaraban matriks, maka kita boleh meningkatkan kelajuan SAM dengan ketara
Seterusnya, kita akan membandingkan daya tampung (img/s) dan overhed memori (GiB) SAM Wujudkan garis dasar. Kemudian terdapat proses pengoptimuman
Ayat yang perlu ditulis semula ialah: Bfloat16 separuh ketepatan (ditambah penyegerakan GPU dan pemprosesan batch) untuk menyelesaikan masalah di atas
ialah, kurangkan masa yang diperlukan untuk masa pendaraban matriks, artikel ini bertukar kepada bfloat16. bfloat16 ialah jenis separuh ketepatan yang biasa digunakan, yang boleh menjimatkan banyak masa dan memori pengkomputeran dengan mengurangkan ketepatan setiap parameter dan pengaktifan Selain itu, artikel ini dijumpai Terdapat dua tempat yang boleh dioptimumkan untuk membuang penyegerakan GPU
Secara khusus, lebih mudah difahami berdasarkan gambar di atas. pengekod imej SAM, Terdapat dua pembolehubah q_coords dan k_coords yang bertindak sebagai penimbang koordinat, dan pembolehubah ini diperuntukkan dan diproses pada CPU. Walau bagaimanapun, sebaik sahaja pembolehubah ini digunakan untuk mengindeks dalam rel_pos_resized, operasi pengindeksan mengalihkan pembolehubah ini secara automatik ke GPU, menyebabkan masalah penyegerakan GPU. Untuk menyelesaikan masalah ini, penyelidikan menunjukkan bahawa bahagian ini boleh diselesaikan dengan menulis semula menggunakan fungsi obor.where seperti yang ditunjukkan di atas
Penjejakan TerasSelepas menggunakan perubahan ini, kami dapati Terdapat perubahan yang ketara. jurang masa antara panggilan kernel individu, terutamanya dengan kelompok kecil (di sini 1). Untuk mendapatkan pemahaman yang lebih mendalam tentang fenomena ini, kami memulakan analisis prestasi inferens SAM dengan saiz kelompok 8
Semasa menganalisis masa yang dibelanjakan bagi setiap teras, kami mendapati bahawa majoriti GPU untuk Masa SAM adalah dibelanjakan untuk kernel mengikut unsur dan operasi softmax
Kini anda dapat melihat bahawa overhed relatif pendaraban matriks adalah jauh lebih kecil.
Menggabungkan penyegerakan GPU dan pengoptimuman bfloat16, prestasi SAM dipertingkatkan sebanyak 3 kali ganda.
Torch.compile (+pemecahan graf dan graf CUDA)
Saya menemui banyak operasi kecil semasa kajian SAM. Penyelidik percaya bahawa menggunakan pengkompil untuk menyatukan operasi ini adalah sangat berfaedah, jadi PyTorch membuat pengoptimuman berikut untuk torch.compile
Menyatukan urutan operasi seperti nn.LayerNorm atau nn.GELU menjadi satu kernel GPU
Kendalian fius serta-merta mengikuti kernel pendaraban matriks untuk mengurangkan bilangan panggilan kernel GPU.
Penjejakan teras Menurut keputusan, torch.compile berfungsi dengan sangat baik sehingga sebahagian besar masa , dan kemudian setiap varian GEMM. Ukuran berikut adalah untuk saiz kelompok 8 dan ke atas. SDPA: scaled_dot_product_attention Seterusnya, artikel ini menjalankan eksperimen pada SDPA (scaled_dot_product_attention), memfokuskan pada mekanisme perhatian. Secara umum, mekanisme perhatian asli berskala kuadratik dengan panjang jujukan dalam masa dan ingatan. Operasi SDPA PyTorch dibina berdasarkan prinsip perhatian cekap memori Flash Attention, FlashAttentionV2 dan xFormer, yang boleh mempercepatkan perhatian GPU dengan ketara. Digabungkan dengan torch.compile, operasi ini membenarkan ekspresi dan gabungan corak biasa dalam varian MultiheadAttention. Selepas perubahan kecil, model kini boleh menggunakan perhatian_produk_berskala. Penjejakan Teras Anda kini boleh melihat kernel perhatian yang cekap memori mengambil banyak masa pengiraan pada GPU: skala_titik _perhatian_produk, ya Meningkatkan saiz kelompok dengan ketara. Graf di bawah menunjukkan perubahan untuk saiz kelompok 32 dan ke atas. Seterusnya, kajian itu menjalankan eksperimen ke atas Triton, NestedTensor, batch Predict_torch, kuantisasi int8, sparsity separa berstruktur (2:4) dan operasi lain Sebagai contoh, artikel ini menggunakan a
Penghujung artikel ialah sparsity separa berstruktur. Kajian menunjukkan bahawa pendaraban matriks masih menjadi hambatan yang perlu dihadapi. Penyelesaiannya adalah dengan menggunakan sparsifikasi untuk menganggarkan pendaraban matriks. Dengan matriks yang jarang (iaitu mensifarkan nilai) lebih sedikit bit boleh digunakan untuk menyimpan pemberat dan tensor pengaktifan. Proses menetapkan pemberat dalam tensor yang ditetapkan kepada sifar dipanggil pemangkasan. Pemangkasan berat yang lebih kecil berpotensi mengurangkan saiz model tanpa kehilangan ketepatan yang ketara. Terdapat banyak cara untuk memangkas, bermula daripada tidak berstruktur sepenuhnya kepada berstruktur tinggi. Walaupun pemangkasan tidak berstruktur secara teorinya mempunyai kesan minimum pada ketepatan, dalam kes yang jarang GPU mungkin mengalami kemerosotan prestasi yang ketara walaupun sangat cekap apabila melakukan pendaraban matriks padat yang besar. Satu kaedah pemangkasan baru-baru ini disokong oleh PyTorch ialah sparsity separa berstruktur (atau 2:4), yang bertujuan untuk mencari keseimbangan. Kaedah penyimpanan jarang ini mengurangkan tensor asal sebanyak 50% sambil menghasilkan output tensor padat. Sila rujuk rajah di bawah untuk penjelasan Untuk menggunakan format storan jarang ini dan kernel pantas yang berkaitan, langkah seterusnya Apa yang dilakukannya ialah pemberat prun. Artikel ini memilih dua pemberat terkecil untuk pemangkasan pada kesederhanaan 2:4 Menukar pemberat daripada reka letak lalai PyTorch ("strided") kepada reka letak jarang separa berstruktur ini adalah mudah. Untuk melaksanakan apply_sparse (model), hanya 32 baris kod Python diperlukan: pada jarak 2: 4, kami memerhatikan prestasi puncak SAM dengan vit_b dan saiz kelompok 32 Ringkasan artikel adalah seperti berikut: Artikel ini memperkenalkan cara terpantas untuk melaksanakan Segmen Apa-apa pada PyTorch setakat ini dengan satu siri ciri baharu yang dikeluarkan secara rasmi artikel menulis semula SAM asal dalam PyTorch tulen tanpa kehilangan ketepatan #🎜🎜 # Untuk pembaca yang berminat, anda boleh menyemak blog asal untuk maklumat lanjut
Atas ialah kandungan terperinci Pasukan PyTorch melaksanakan semula model 'split everything' lapan kali lebih pantas daripada pelaksanaan asal. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!