@@ -34,7 +34,7 @@ if __name__ == "__main__":
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

# example: fine-tune RWKV 1.5B using 8xA100 40G
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
@@ -114,7 +114,7 @@ if __name__ == "__main__":
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
@@ -138,6 +138,10 @@ if __name__ == "__main__":
if args.precision == "fp16":
    rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

os.environ["RWKV_JIT"] = "1"
if "deepspeed_stage_3" in args.strategy:
    os.environ["RWKV_JIT"] = "0"

import torch

torch.backends.cudnn.benchmark = True
@@ -260,7 +264,7 @@ if __name__ == "__main__":
args.load_model = f"{args.proj_dir}/rwkv-init.pth"  # init weights to tmp file
generate_init_weight(model, args.load_model)

print(f"\nLoading {args.load_model}...\n")
print(f"########## Loading {args.load_model}... ##########")
load_dict = torch.load(args.load_model, map_location="cpu")
model.load_state_dict(load_dict)