Baru-baru ini, program AI chatbot ChatGPT yang dibangunkan oleh OpenAI telah melanda komuniti AI utama. Semangat semua orang terhadapnya semakin meningkat, dan mereka terus memanfaatkan potensinya.
Sesetengah penyelidik tidak boleh duduk diam dan mula tertanya-tanya bagaimana untuk membangunkan perisian sumber terbuka yang setara dengan ChatGPT. Bagi mereka yang masih belum mengambil tindakan, berikut adalah contoh rujukan kali ini Projek (PaLM + RLHF) yang akan kami perkenalkan di bawah melaksanakan fungsi tersebut.
Alamat projek: https://github.com/lucidrains/PaLM-rlhf-pytorch
Projek ini adalah untuk melaksanakan RLHF di atas PaLM seni bina ( pembelajaran pengukuhan maklum balas manusia). Pada asasnya sama seperti ChatGPT, perbezaannya ialah PaLM digunakan. PaLM ialah model bahasa yang besar dengan 540 bilion parameter yang dilatih pada seni bina AI umum Google "Pathways". RLHF ialah pengenalan ChatGPT tentang "data berlabel buatan + pembelajaran pengukuhan" (RLHF) berdasarkan model siri GPT 3.5 untuk memperhalusi model bahasa pra-latihan secara berterusan, bertujuan untuk membolehkan model bahasa besar (LLM) belajar untuk memahami arahan manusia dan Belajar untuk memberikan jawapan yang optimum berdasarkan gesaan yang diberikan.
Kalau nak tahu lebih lanjut tentang RLHF, boleh rujuk: https://huggingface.co/blog/rlhf
Bak kata seorang netizen: "Dalam bidang AI, setiap apabila terdapat projek khas Terobosan, pembangun akan mengeluarkan semula versi sumber terbuka tidak lama lagi "
Walau bagaimanapun, projek itu pada masa ini. hanya mengandungi seni bina dan kod Latihan, tiada pemberat pra-latihan. Dalam arahan penggunaan, dokumen itu juga menunjukkan bahawa PaLM mesti dilatih terlebih dahulu.
Beberapa netizen turut menyatakan kebimbangan mengenai perkara ini, dengan berkata: Ini bukan projek di luar kotak, ia hanya struktur, seperti cangkerang, ia memerlukan overhed yang mahal untuk melatih Tiada organisasi boleh melatih PaLM seperti Google.
Sesetengah netizen berkata: "Sangat teruk jika tidak mempunyai pemberat yang telah dilatih. Pegawai perlu melepaskan sekurang-kurangnya 50% daripada pemberat yang jarang, dan biarkan pemaju berlatih selebihnya dengan sendirinya. Ini adalah pilihan terbaik.”
Namun, sesetengah netizen berkata mereka akan mencubanya:
Mari kita lihat bagaimana projek ini berfungsi.
$ pip install palm-rlhf-pytorch
Latih PaLM pertama, sama seperti pengubah autoregresif lain.
import torch from palm_rlhf_pytorch import PaLM palm = PaLM( num_tokens = 20000, dim = 512, depth = 12 ).cuda() seq = torch.randint(0, 20000, (1, 2048)).cuda() loss = palm(seq, return_loss = True)loss.backward() # after much training, you can now generate sequences generated = palm.generate(2048) # (1, 2048)
Model ganjaran kemudiannya dilatih menggunakan maklum balas manusia yang dipilih susun. Dalam kertas asal, adalah tidak mungkin untuk mendapatkan model ganjaran yang diperhalusi daripada pengubah terlatih tanpa pemasangan berlebihan. Pengarang projek menyediakan pilihan untuk menggunakan LoRA untuk penalaan halus.
import torch from palm_rlhf_pytorch import PaLM, RewardModel palm = PaLM( num_tokens = 20000, dim = 512, depth = 12, causal = False ) reward_model = RewardModel( palm, num_binned_output = 5 # say rating from 1 to 5 ).cuda() # mock data seq = torch.randint(0, 20000, (1, 1024)).cuda()prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response labels = torch.randint(0, 5, (1,)).cuda() # train loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)loss.backward() # after much training reward = reward_model(seq, prompt_mask = prompt_mask)
Akhir sekali, hantarkan pengubah dan model ganjaran kepada RLHFTrainer.
import torch from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer # load your pretrained palm palm = PaLM( num_tokens = 20000, dim = 512, depth = 12 ).cuda() palm.load('./path/to/pretrained/palm.pt') # load your pretrained reward model reward_model = RewardModel( palm, num_binned_output = 5 ).cuda() reward_model.load('./path/to/pretrained/reward_model.pt') # ready your list of prompts for reinforcement learning prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts # pass it all to the trainer and train trainer = RLHFTrainer( palm = palm, reward_model = reward_model, prompt_token_ids = prompts ) trainer.train(num_episodes = 50000) # then, if it succeeded... # generate say 10 samples and use the reward model to return the best one answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
Atas ialah kandungan terperinci Cukup pantas! Projek sumber terbuka setara ChatGPT yang popular ada di sini, netizen: Saya bimbang saya tidak dapat menjalankannya. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!