|
|
|
|
@@ -4,6 +4,9 @@ import pandas as pd
|
|
|
|
|
from torch.utils.data import Dataset, random_split
|
|
|
|
|
from transformers import TrainingArguments, Trainer
|
|
|
|
|
|
|
|
|
|
# To save memory, use bfloat16 on CPU:
# torch.set_default_dtype(torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub identifier of the base model to fine-tune.
# NOTE(review): the 'decapoda-research' LLaMA repos have since been removed
# from the Hub — confirm this identifier still resolves (community mirrors
# such as 'huggyllama/llama-7b' are the usual replacement).
MODEL = 'decapoda-research/llama-7b-hf'

# Path to the training corpus (CSV of Elon Musk tweets), relative to the
# working directory the script is launched from.
DATA_FILE_PATH = 'datasets/elon_musk_tweets.csv'

# Directory where the Trainer writes fine-tuned checkpoints and final output.
OUTPUT_DIR = './trained'
|
|
|
|
|
@@ -45,7 +48,7 @@ training_args = TrainingArguments(
|
|
|
|
|
logging_dir='./logs',
|
|
|
|
|
output_dir=OUTPUT_DIR,
|
|
|
|
|
no_cuda=True,
|
|
|
|
|
# bf16=True,
|
|
|
|
|
bf16=True,
|
|
|
|
|
per_device_eval_batch_size=1,
|
|
|
|
|
per_device_train_batch_size=1)
|
|
|
|
|
|
|
|
|
|
|