Update hf-training-example.py

main
randaller 3 years ago committed by GitHub
parent 48ad5acd59
commit 5786267144
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,9 @@ import pandas as pd
from torch.utils.data import Dataset, random_split from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer from transformers import TrainingArguments, Trainer
# # to save memory use bfloat16 on cpu
# torch.set_default_dtype(torch.bfloat16)
MODEL = 'decapoda-research/llama-7b-hf' MODEL = 'decapoda-research/llama-7b-hf'
DATA_FILE_PATH = 'datasets/elon_musk_tweets.csv' DATA_FILE_PATH = 'datasets/elon_musk_tweets.csv'
OUTPUT_DIR = './trained' OUTPUT_DIR = './trained'
@ -45,7 +48,7 @@ training_args = TrainingArguments(
logging_dir='./logs', logging_dir='./logs',
output_dir=OUTPUT_DIR, output_dir=OUTPUT_DIR,
no_cuda=True, no_cuda=True,
# bf16=True, bf16=True,
per_device_eval_batch_size=1, per_device_eval_batch_size=1,
per_device_train_batch_size=1) per_device_train_batch_size=1)

Loading…
Cancel
Save