Update hf-training-example.py

main
randaller 3 years ago committed by GitHub
parent c673bd2fc9
commit 312fa87dd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,12 +14,9 @@ OUTPUT_DIR = './trained'
texts = pd.read_csv(DATA_FILE_PATH)['text']
tokenizer = llamahf.LLaMATokenizer.from_pretrained(MODEL)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = llamahf.LLaMAForCausalLM.from_pretrained(MODEL).cpu()
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
model.resize_token_embeddings(len(tokenizer))
class TextDataset(Dataset):
def __init__(self, txt_list, tokenizer, max_length):

Loading…
Cancel
Save