Rumah > pembangunan bahagian belakang > Tutorial Python > Contoh kod penalaan halus Huggingface BART: Set data WMT16 untuk melatih teg baharu untuk terjemahan

Contoh kod penalaan halus Huggingface BART: Set data WMT16 untuk melatih teg baharu untuk terjemahan

王林
Lepaskan: 2023-04-10 14:41:06
ke hadapan
1372 orang telah melayarinya

Jika anda ingin menguji seni bina baharu pada tugas terjemahan, seperti melatih teg baharu pada set data tersuai, ia akan menyusahkan untuk dikendalikan, jadi dalam artikel ini, saya akan memperkenalkan pra-pemprosesan menambah teg baharu. Memproses langkah dan memperkenalkan cara memperhalusi model.

Oleh kerana Huggingface Hub mempunyai banyak model terlatih, mudah untuk mencari penanda terlatih. Tetapi mungkin agak sukar untuk menambah penanda Mari kita perkenalkan sepenuhnya cara melaksanakannya Mula-mula, muatkan dan praproses set data.

Memuatkan set data

Kami menggunakan set data WMT16 dan subset Romania-Inggerisnya. Fungsi load_dataset() akan memuat turun dan memuatkan mana-mana set data yang tersedia daripada Huggingface.

import datasets
 
 dataset = datasets.load_dataset("stas/wmt16-en-ro-pre-processed", cache_dir="./wmt16-en_ro")
Salin selepas log masuk

Contoh kod penalaan halus Huggingface BART: Set data WMT16 untuk melatih teg baharu untuk terjemahan

Kandungan set data boleh dilihat dalam Rajah 1 di atas. Kita perlu "meratakannya" supaya kita boleh mengakses data dengan lebih baik dan menyimpannya ke cakera keras.

def flatten(batch):
 batch['en'] = batch['translation']['en']
 batch['ro'] = batch['translation']['ro']
 
 return batch
 
 # Map the 'flatten' function
 train = dataset['train'].map( flatten )
 test = dataset['test'].map( flatten )
 validation = dataset['validation'].map( flatten )
 
 # Save to disk
 train.save_to_disk("./dataset/train")
 test.save_to_disk("./dataset/test")
 validation.save_to_disk("./dataset/validation")
Salin selepas log masuk

Seperti yang anda lihat dalam Rajah 2 di bawah, dimensi "terjemahan" telah dipadamkan daripada set data.

Contoh kod penalaan halus Huggingface BART: Set data WMT16 untuk melatih teg baharu untuk terjemahan

Tagger

Tagger menyediakan semua kerja yang diperlukan untuk melatih tokenizer. Ia terdiri daripada empat komponen asas: (tetapi bukan keempat-empatnya diperlukan)

Model: Bagaimana tokenizer akan memecahkan setiap perkataan. Sebagai contoh, diberi perkataan "bermain": i) Model BPE menguraikannya kepada dua token "bermain" + "ing", ii) WordLevel menganggapnya sebagai satu token.

Normalizers: Beberapa transformasi yang perlu berlaku pada teks. Terdapat penapis untuk menukar Unicode, huruf kecil atau mengalih keluar kandungan.

Pra-Tokenizer: Fungsi yang memberikan fleksibiliti yang lebih besar untuk mengendalikan teks. Sebagai contoh, bagaimana untuk bekerja dengan nombor. Sekiranya nombor 100 dianggap "100" atau "1", "0", "0"?

Pos-Pemproses: Spesifik pasca pemprosesan bergantung pada pilihan pra -model terlatih. Sebagai contoh, tambah token [BOS] (permulaan ayat) atau [EOS] (akhir ayat) pada input BERT.

Kod di bawah menggunakan model BPE, Penormal huruf kecil dan Pra-Tokenizer kosong. Kemudian mulakan objek pelatih dengan nilai lalai, terutamanya termasuk

1 Gunakan 50265 untuk saiz perbendaharaan kata agar selaras dengan penanda Bahasa Inggeris BART

2 dan , 3. Perbendaharaan kata awal, ini ialah senarai yang telah ditetapkan untuk setiap proses permulaan model.

Langkah terakhir dalam menggunakan Huggingface ialah menyambungkan model Pelatih dan BPE dan menghantar set data. Bergantung pada sumber data, fungsi latihan yang berbeza boleh digunakan. Kami akan menggunakan train_from_iterator().
from tokenizers import normalizers, pre_tokenizers, Tokenizer, models, trainers
 
 # Build a tokenizer
 bpe_tokenizer = Tokenizer(models.BPE())
 bpe_tokenizer.normalizer = normalizers.Lowercase()
 bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
 
 trainer = trainers.BpeTrainer(
 vocab_size=50265,
 special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
 initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
 )
Salin selepas log masuk

BART Spinner
def batch_iterator():
 batch_length = 1000
 for i in range(0, len(train), batch_length):
 yield train[i : i + batch_length]["ro"]
 
 bpe_tokenizer.train_from_iterator( batch_iterator(), length=len(train), trainer=trainer )
 
 bpe_tokenizer.save("./ro_tokenizer.json")
Salin selepas log masuk

Tagger baharu kini tersedia.

Pada baris 5 kod di atas, adalah perlu untuk menetapkan teg pelapik untuk penanda Romania. Memandangkan ia akan digunakan pada baris 9, tokenizer menggunakan padding supaya semua input adalah saiz yang sama.
from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
 en_tokenizer = AutoTokenizer.from_pretrained( "facebook/bart-base" );
 ro_tokenizer = PreTrainedTokenizerFast.from_pretrained( "./ro_tokenizer.json" );
 ro_tokenizer.pad_token = en_tokenizer.pad_token
 
 def tokenize_dataset(sample):
 input = en_tokenizer(sample['en'], padding='max_length', max_length=120, truncation=True)
 label = ro_tokenizer(sample['ro'], padding='max_length', max_length=120, truncation=True)
 
 input["decoder_input_ids"] = label["input_ids"]
 input["decoder_attention_mask"] = label["attention_mask"]
 input["labels"] = label["input_ids"]
 
 return input
 
 train_tokenized = train.map(tokenize_dataset, batched=True)
 test_tokenized = test.map(tokenize_dataset, batched=True)
 validation_tokenized = validation.map(tokenize_dataset, batched=True)
Salin selepas log masuk

Berikut ialah proses latihan:

Proses ini juga sangat mudah Muatkan model asas bart (baris 4), tetapkan parameter latihan (baris 6), dan gunakan objek Jurulatih untuk mengikat segala-galanya (baris 22), dan memulakan proses (baris 29). Hiperparameter di atas adalah untuk tujuan ujian, jadi jika anda ingin mendapatkan hasil yang terbaik, anda perlu menetapkan hiperparameter Kami boleh menjalankan menggunakan parameter ini.
from transformers import BartForConditionalGeneration
 from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
 
 model = BartForConditionalGeneration.from_pretrained("facebook/bart-base" )
 
 training_args = Seq2SeqTrainingArguments(
 output_dir="./",
 evaluation_strategy="steps",
 per_device_train_batch_size=2,
 per_device_eval_batch_size=2,
 predict_with_generate=True,
 logging_steps=2,# set to 1000 for full training
 save_steps=64,# set to 500 for full training
 eval_steps=64,# set to 8000 for full training
 warmup_steps=1,# set to 2000 for full training
 max_steps=128, # delete for full training
 overwrite_output_dir=True,
 save_total_limit=3,
 fp16=False, # True if GPU
 )
 
 trainer = Seq2SeqTrainer(
 model=model,
 args=training_args,
 train_dataset=train_tokenized,
 eval_dataset=validation_tokenized,
 )
 
 trainer.train()
Salin selepas log masuk

Inferens

Proses inferens juga sangat mudah. ​​Cuma muatkan model yang diperhalusi dan gunakan kaedah jana() untuk menukarnya, anda perlu memberi perhatian kepada sumbernya. En) dan sasaran (RO) dengan menggunakan tokenizer yang sesuai.

Ringkasan

Walaupun tokenisasi mungkin kelihatan seperti operasi asas apabila menggunakan pemprosesan bahasa semula jadi (NLP), ia merupakan langkah kritikal yang tidak boleh diabaikan. Kemunculan HuggingFace memudahkan kami menggunakannya, yang memudahkan kami melupakan prinsip asas tokenisasi dan hanya bergantung pada model yang telah dilatih. Tetapi apabila kita ingin melatih model baharu sendiri, memahami proses tokenisasi dan kesannya terhadap tugas hiliran adalah penting, jadi adalah perlu untuk membiasakan diri dan menguasai operasi asas ini.

Kod artikel ini: https://github.com/AlaFalaki/tutorial_notebooks/blob/main/translation/hf_bart_translation.ipynb

Atas ialah kandungan terperinci Contoh kod penalaan halus Huggingface BART: Set data WMT16 untuk melatih teg baharu untuk terjemahan. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Label berkaitan:
sumber:51cto.com
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan