diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 658b93d..b091bef 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -13,6 +13,18 @@ from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 
+
+def __nop(ob):
+    return ob
+
+
+MyModule = nn.Module
+MyFunction = __nop
+if os.environ["RWKV_JIT"] == "1":
+    MyModule = torch.jit.ScriptModule
+    MyFunction = torch.jit.script_method
+
+
 ########################################################################################################
 # CUDA Kernel
 ########################################################################################################
@@ -88,7 +100,7 @@ def RUN_CUDA(B, T, C, w, u, k, v):
 ########################################################################################################
 
 
-class RWKV_TimeMix(torch.jit.ScriptModule):
+class RWKV_TimeMix(MyModule):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
@@ -128,7 +140,7 @@ class RWKV_TimeMix(torch.jit.ScriptModule):
 
         self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
 
-    @torch.jit.script_method
+    @MyFunction
     def jit_func(self, x):
 
         # Mix x with the previous timestep to produce xk, xv, xr
@@ -155,7 +167,7 @@ class RWKV_TimeMix(torch.jit.ScriptModule):
         return rwkv
 
 
-class RWKV_ChannelMix(torch.jit.ScriptModule):
+class RWKV_ChannelMix(MyModule):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
@@ -177,7 +189,7 @@ class RWKV_ChannelMix(torch.jit.ScriptModule):
         self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
         self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
 
-    @torch.jit.script_method
+    @MyFunction
     def forward(self, x):
         xx = self.time_shift(x)
         xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 397303c..28c0928 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -32,9 +32,9 @@ if __name__ == "__main__":
     # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
     # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
     # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
-    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
 
-    # example: fine-tune RWKV 1.5B using 8xA100 40G
+    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
     #
     # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
     # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
@@ -56,20 +56,20 @@ if __name__ == "__main__":
     parser = Trainer.add_argparse_args(parser)
 
     parser.add_argument("--load_model", default="", type=str)
-    parser.add_argument("--wandb", default="", type=str) # wandb project name
+    parser.add_argument("--wandb", default="", type=str) # wandb project name
     parser.add_argument("--proj_dir", default="out", type=str)
 
     parser.add_argument("--data_file", default="", type=str)
     parser.add_argument("--data_type", default="utf-8", type=str)
-    parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
+    parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
 
     parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has xxx steps
+    parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has xxx steps
     parser.add_argument("--epoch_count", default=500, type=int)
     parser.add_argument("--epoch_begin", default=0, type=int)
     parser.add_argument("--epoch_save", default=5, type=int)
 
-    parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
+    parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
     parser.add_argument("--beta2", default=0.99, type=float)
     parser.add_argument("--adam_eps", default=1e-8, type=float)
 
-    parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
+    parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
 
     args = parser.parse_args()
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
@@ -114,7 +114,7 @@ if __name__ == "__main__":
 # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
 #
 # Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
-# Found deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
+# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
 # Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
 #
 ############################################################################
@@ -138,6 +138,10 @@ if __name__ == "__main__":
     if args.precision == "fp16":
         rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
 
+    os.environ["RWKV_JIT"] = "1"
+    if "deepspeed_stage_3" in args.strategy:
+        os.environ["RWKV_JIT"] = "0"
+
     import torch
 
     torch.backends.cudnn.benchmark = True
@@ -260,7 +264,7 @@ if __name__ == "__main__":
         args.load_model = f"{args.proj_dir}/rwkv-init.pth" # init weights to tmp file
         generate_init_weight(model, args.load_model)
 
-    print(f"\nLoading {args.load_model}...\n")
+    print(f"########## Loading {args.load_model}... ##########")
     load_dict = torch.load(args.load_model, map_location="cpu")
     model.load_state_dict(load_dict)