@@ -25,8 +25,17 @@ if __name__ == "__main__":
     warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")

     ########################################################################################################
+
+    # example: train a simple L12-D768 RWKV on dummy data
+    #
+    # python train.py --load_model "" --wandb "" --proj_dir "out" \
+    # --data_file "" --data_type "dummy" --vocab_size 0 \
+    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
+    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

-    # example: train a simple L6-D512 RWKV from scratch
+    # example: train a simple L6-D512 RWKV from scratch on enwik8
     #
     # python train.py --load_model "" --wandb "" --proj_dir "out" \
     # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
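
The dummy-data example above maps one-to-one onto the argparse flags documented in the hunks below. As a minimal sketch of that mapping (script-level flags only; Trainer flags such as --accelerator and --precision come from Trainer.add_argparse_args, per the next hunk), one could parse a slice of the example programmatically:

import argparse

# parse a subset of the dummy-data example's flags (defaults taken from the diff)
parser = argparse.ArgumentParser()
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
args = parser.parse_args("--data_type dummy --ctx_len 128 --n_layer 12 --n_embd 768".split())
print(args)  # Namespace(data_type='dummy', ctx_len=128, n_layer=12, n_embd=768)
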
@@ -56,8 +65,8 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser = Trainer.add_argparse_args(parser)

-    parser.add_argument("--load_model", default="", type=str)
-    parser.add_argument("--wandb", default="", type=str) # wandb project name
+    parser.add_argument("--load_model", default="", type=str) # full path, with .pth
+    parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
     parser.add_argument("--proj_dir", default="out", type=str)

     parser.add_argument("--data_file", default="", type=str)
@@ -65,26 +74,28 @@ if __name__ == "__main__":
     parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)

     parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has xxx steps
-    parser.add_argument("--epoch_count", default=500, type=int)
-    parser.add_argument("--epoch_begin", default=0, type=int)
-    parser.add_argument("--epoch_save", default=5, type=int)
+    parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
+    parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
+    parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
+    parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"

     parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
-    parser.add_argument("--pre_ffn", default=0, type=int)
-    parser.add_argument("--head_qk", default=0, type=int)
+    parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
+    parser.add_argument("--head_qk", default=0, type=int) # my headQK trick. try 256 if you want to test it

-    parser.add_argument("--lr_init", default=6e-4, type=float)
+    parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
-    parser.add_argument("--warmup_steps", default=0, type=int)
+    parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
     parser.add_argument("--beta1", default=0.9, type=float)
-    parser.add_argument("--beta2", default=0.99, type=float)
+    parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
     parser.add_argument("--adam_eps", default=1e-8, type=float)

     parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
     parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode
+    parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
+    parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 500 might be faster (but more VRAM)

     args = parser.parse_args()
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
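
Since a mini "epoch" is simply [epoch_steps] optimizer steps, the token budget per epoch follows from micro_bsz, ctx_len, and device count. A back-of-envelope sketch using the dummy example's settings (the per-device multiplier is an assumption about how DDP scales the effective batch; it is not computed in train.py):

# tokens per mini "epoch" = epoch_steps * micro_bsz * ctx_len * devices
micro_bsz, ctx_len, epoch_steps, devices = 16, 128, 1000, 1
tokens_per_step = micro_bsz * ctx_len * devices
tokens_per_epoch = tokens_per_step * epoch_steps
print(f"{tokens_per_step:,} tokens/step, {tokens_per_epoch:,} tokens/epoch")
# 2,048 tokens/step, 2,048,000 tokens/epoch
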
@@ -120,7 +131,7 @@ if __name__ == "__main__":
 #
 # Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
 #
-# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
+# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
 #
 # Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
 # Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
@ -134,8 +145,7 @@ if __name__ == "__main__":
|
|
|
|
if not os.path.exists(args.proj_dir):
|
|
|
|
if not os.path.exists(args.proj_dir):
|
|
|
|
os.makedirs(args.proj_dir)
|
|
|
|
os.makedirs(args.proj_dir)
|
|
|
|
|
|
|
|
|
|
|
|
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx"]
|
|
|
|
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy"]
|
|
|
|
assert len(args.data_file) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if args.lr_final == 0 or args.lr_init == 0:
|
|
|
|
if args.lr_final == 0 or args.lr_init == 0:
|
|
|
|
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
|
|
|
|
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
|
|
|
|
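
Allowing "dummy" here, together with dropping the data_file length assert, is what lets the example above run with --data_file "". The dataset implementation itself is outside this diff; purely as a hypothetical sketch, a stand-in that satisfies the usual (input, target) contract could look like:

import torch
from torch.utils.data import Dataset

class DummyDataset(Dataset):
    """Hypothetical sketch only -- not the repo's actual "dummy" implementation."""

    def __init__(self, vocab_size=256, ctx_len=128, length=16000):
        self.vocab_size = vocab_size  # assumption: byte-level fallback when vocab_size=0
        self.ctx_len = ctx_len
        self.length = length  # number of samples per pass

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # random token ids; input and target are the usual one-token-shifted pair
        data = torch.randint(0, self.vocab_size, (self.ctx_len + 1,), dtype=torch.long)
        return data[:-1], data[1:]
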
@@ -196,14 +206,17 @@ if __name__ == "__main__":
             lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))

         for param_group in trainer.optimizers[0].param_groups:
-            if self.args.my_pile_mode == 0:
-                param_group["lr"] = lr * param_group["my_lr_scale"]
-            elif self.args.my_pile_mode == 2:
-                if param_group["my_lr_scale"] > 1:
-                    param_group["lr"] = lr * 5
-                else:
-                    param_group["lr"] = lr
-            # print(param_group["lr"], param_group["my_lr_scale"])
+            if args.layerwise_lr > 0:
+                if self.args.my_pile_mode == 0:
+                    param_group["lr"] = lr * param_group["my_lr_scale"]
+                elif self.args.my_pile_mode == 2:
+                    if param_group["my_lr_scale"] > 1:
+                        param_group["lr"] = lr * 5
+                    else:
+                        param_group["lr"] = lr
+                # print(param_group["lr"], param_group["my_lr_scale"])
+            else:
+                param_group["lr"] = lr

         trainer.my_lr = lr
         # rank_zero_info(f"{g_step} {lr}")
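
The schedule above is a geometric interpolation: lr(progress) = lr_init * exp(log(lr_final / lr_init) * progress), so lr(0) = lr_init and lr(1) = lr_final, and with --layerwise_lr > 0 each param group's result is further multiplied by its my_lr_scale. A standalone sketch of the base curve with the defaults (layerwise scaling not applied):

import math

def lr_at(progress, lr_init=6e-4, lr_final=1e-5):
    # same formula as the hunk above
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)

for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"progress={p:.2f}  lr={lr_at(p):.2e}")
# progress=0.00  lr=6.00e-04
# progress=0.25  lr=2.16e-04
# progress=0.50  lr=7.75e-05
# progress=0.75  lr=2.78e-05
# progress=1.00  lr=1.00e-05
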
@@ -311,6 +324,9 @@ if __name__ == "__main__":
         args,
         callbacks=[train_callback(args)],
     )

+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
     # must set shuffle=False, persistent_workers=False (because worker is in another thread)
     data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
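
Note the bucket override interprets --ds_bucket_mb as decimal megabytes (x 1000 x 1000), not MiB. A quick check of what the default becomes:

ds_bucket_mb = 200                        # the flag's default
bucket_size = ds_bucket_mb * 1000 * 1000  # value written into the deepspeed zero_optimization config
print(f"{bucket_size:,} bytes")           # 200,000,000 (MiB would have been 209,715,200)
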