Update hf-training-example.py

main
randaller 3 years ago committed by GitHub
parent a31651c9f0
commit b44bc2dd07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,7 +14,6 @@ OUTPUT_DIR = './trained'
texts = pd.read_csv(DATA_FILE_PATH)['text']
tokenizer = llamahf.LLaMATokenizer.from_pretrained(MODEL)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = llamahf.LLaMAForCausalLM.from_pretrained(MODEL).cpu()
@ -24,7 +23,8 @@ class TextDataset(Dataset):
self.input_ids = []
self.attn_masks = []
for txt in txt_list:
encodings_dict = tokenizer(txt, truncation=True, max_length=max_length, padding="max_length")
# encodings_dict = tokenizer(txt, truncation=True, max_length=max_length, padding="max_length")
encodings_dict = tokenizer(txt, truncation=True, max_length=max_length, pad_to_max_length=False)
self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
@ -45,7 +45,6 @@ training_args = TrainingArguments(
logging_dir='./logs',
output_dir=OUTPUT_DIR,
no_cuda=True,
# bf16=True,
per_device_eval_batch_size=1,
per_device_train_batch_size=1)

Loading…
Cancel
Save