|
|
|
|
@ -32,9 +32,9 @@ if __name__ == "__main__":
|
|
|
|
|
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
|
|
|
|
|
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
|
|
|
|
|
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
|
|
|
|
|
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
|
|
|
|
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
|
|
|
|
|
|
|
|
|
# example: fine-tune RWKV 1.5B using 8xA100 40G
|
|
|
|
|
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
|
|
|
|
|
#
|
|
|
|
|
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
|
|
|
|
|
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
|
|
|
|
|
@ -56,20 +56,20 @@ if __name__ == "__main__":
|
|
|
|
|
parser = Trainer.add_argparse_args(parser)
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--load_model", default="", type=str)
|
|
|
|
|
parser.add_argument("--wandb", default="", type=str) # wandb project name
|
|
|
|
|
parser.add_argument("--wandb", default="", type=str) # wandb project name
|
|
|
|
|
parser.add_argument("--proj_dir", default="out", type=str)
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--data_file", default="", type=str)
|
|
|
|
|
parser.add_argument("--data_type", default="utf-8", type=str)
|
|
|
|
|
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
|
|
|
|
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--ctx_len", default=1024, type=int)
|
|
|
|
|
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has xxx steps
|
|
|
|
|
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has xxx steps
|
|
|
|
|
parser.add_argument("--epoch_count", default=500, type=int)
|
|
|
|
|
parser.add_argument("--epoch_begin", default=0, type=int)
|
|
|
|
|
parser.add_argument("--epoch_save", default=5, type=int)
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
|
|
|
|
|
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
|
|
|
|
|
parser.add_argument("--n_layer", default=6, type=int)
|
|
|
|
|
parser.add_argument("--n_embd", default=512, type=int)
|
|
|
|
|
parser.add_argument("--pre_ffn", default=0, type=int)
|
|
|
|
|
@ -82,7 +82,7 @@ if __name__ == "__main__":
|
|
|
|
|
parser.add_argument("--beta2", default=0.99, type=float)
|
|
|
|
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
|
|
|
|
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
|
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
|
|
|
|
@ -114,7 +114,7 @@ if __name__ == "__main__":
|
|
|
|
|
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
|
|
|
|
|
#
|
|
|
|
|
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
|
|
|
|
|
# Found deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
|
|
|
|
|
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
|
|
|
|
|
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
|
|
|
|
|
#
|
|
|
|
|
############################################################################
|
|
|
|
|
@ -138,6 +138,10 @@ if __name__ == "__main__":
|
|
|
|
|
if args.precision == "fp16":
|
|
|
|
|
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
|
|
|
|
|
|
|
|
|
|
os.environ["RWKV_JIT"] = "1"
|
|
|
|
|
if "deepspeed_stage_3" in args.strategy:
|
|
|
|
|
os.environ["RWKV_JIT"] = "0"
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
|
@ -260,7 +264,7 @@ if __name__ == "__main__":
|
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth" # init weights to tmp file
|
|
|
|
|
generate_init_weight(model, args.load_model)
|
|
|
|
|
|
|
|
|
|
print(f"\nLoading {args.load_model}...\n")
|
|
|
|
|
print(f"########## Loading {args.load_model}... ##########")
|
|
|
|
|
load_dict = torch.load(args.load_model, map_location="cpu")
|
|
|
|
|
model.load_state_dict(load_dict)
|
|
|
|
|
|
|
|
|
|
|