BlinkDL 3 years ago
parent 6ab2e71c25
commit f81349f127

@@ -13,6 +13,18 @@ from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 
+
+def __nop(ob):
+    return ob
+
+
+MyModule = nn.Module
+MyFunction = __nop
+if os.environ["RWKV_JIT"] == "1":
+    MyModule = torch.jit.ScriptModule
+    MyFunction = torch.jit.script_method
+
+
 ########################################################################################################
 # CUDA Kernel
 ########################################################################################################
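Note: the added block replaces hard-coded TorchScript bases with aliases resolved at import time. With RWKV_JIT unset or "0", `__nop` and `nn.Module` leave the classes as plain eager modules; with RWKV_JIT="1" they become scripted. A minimal standalone sketch of the same toggle (the `Demo` class is hypothetical; the patch indexes `os.environ` directly because train.py sets the variable first, while this sketch uses `.get` so it runs on its own):

import os
import torch
import torch.nn as nn

def __nop(ob):
    return ob  # identity decorator: leaves the class/method untouched

MyModule = nn.Module
MyFunction = __nop
if os.environ.get("RWKV_JIT", "0") == "1":
    MyModule = torch.jit.ScriptModule     # base class gains TorchScript compilation
    MyFunction = torch.jit.script_method  # methods get scripted individually

class Demo(MyModule):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @MyFunction
    def forward(self, x):
        return torch.relu(self.proj(x))

print(Demo()(torch.randn(2, 4)).shape)  # works with RWKV_JIT unset or set to "1"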
@@ -88,7 +100,7 @@ def RUN_CUDA(B, T, C, w, u, k, v):
 ########################################################################################################
 
-class RWKV_TimeMix(torch.jit.ScriptModule):
+class RWKV_TimeMix(MyModule):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
@@ -128,7 +140,7 @@ class RWKV_TimeMix(torch.jit.ScriptModule):
         self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
 
-    @torch.jit.script_method
+    @MyFunction
     def jit_func(self, x):
         # Mix x with the previous timestep to produce xk, xv, xr
@@ -155,7 +167,7 @@ class RWKV_TimeMix(torch.jit.ScriptModule):
         return rwkv
 
-class RWKV_ChannelMix(torch.jit.ScriptModule):
+class RWKV_ChannelMix(MyModule):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
@@ -177,7 +189,7 @@ class RWKV_ChannelMix(torch.jit.ScriptModule):
         self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
         self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
 
-    @torch.jit.script_method
+    @MyFunction
     def forward(self, x):
         xx = self.time_shift(x)
         xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
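For context, the `time_shift` call visible in the last hunk is RWKV's token shift: each position mixes its own channels with the previous position's, weighted by the learned `time_mix_k`. A self-contained sketch, assuming `time_shift = nn.ZeroPad2d((0, 0, 1, -1))` as elsewhere in this file:

import torch
import torch.nn as nn

B, T, C = 2, 5, 8                      # batch, sequence length, channels
x = torch.randn(B, T, C)

# Pads one step at the front of the time axis and crops the last step,
# so row t of xx holds row t-1 of x (row 0 becomes zeros).
time_shift = nn.ZeroPad2d((0, 0, 1, -1))
xx = time_shift(x)

time_mix_k = torch.rand(1, 1, C)       # a learned per-channel weight in the real model
xk = x * time_mix_k + xx * (1 - time_mix_k)
print(xk.shape)                        # torch.Size([2, 5, 8])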

@@ -32,9 +32,9 @@ if __name__ == "__main__":
     # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
     # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
     # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
-    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
 
-    # example: fine-tune RWKV 1.5B using 8xA100 40G
+    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
     #
     # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
     # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
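A quick check of the throughput figure added to the comment. The 1.5B command's remaining flags are cut off by the hunk, so the `micro_bsz = 8` and `ctx_len = 1024` below are assumptions chosen to make the arithmetic line up:

# Back-of-envelope check of "1.76it/s = 115k token/s" (per-GPU micro_bsz and
# ctx_len are assumptions; the hunk does not show those flags).
its_per_sec = 1.76
n_gpus = 8
micro_bsz = 8
ctx_len = 1024
tokens_per_sec = its_per_sec * n_gpus * micro_bsz * ctx_len
print(f"{tokens_per_sec:,.0f} tokens/s")  # ~115,343, i.e. the quoted 115k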
@@ -56,20 +56,20 @@ if __name__ == "__main__":
     parser = Trainer.add_argparse_args(parser)
     parser.add_argument("--load_model", default="", type=str)
-    parser.add_argument("--wandb", default="", type=str)  # wandb project name
+    parser.add_argument("--wandb", default="", type=str)  # wandb project name
     parser.add_argument("--proj_dir", default="out", type=str)
     parser.add_argument("--data_file", default="", type=str)
     parser.add_argument("--data_type", default="utf-8", type=str)
-    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
+    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
     parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has xxx steps
+    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has xxx steps
    parser.add_argument("--epoch_count", default=500, type=int)
     parser.add_argument("--epoch_begin", default=0, type=int)
     parser.add_argument("--epoch_save", default=5, type=int)
-    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
+    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
     parser.add_argument("--beta2", default=0.99, type=float)
     parser.add_argument("--adam_eps", default=1e-8, type=float)
-    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
+    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
     args = parser.parse_args()
 
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
@@ -114,7 +114,7 @@ if __name__ == "__main__":
 # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
 #
 # Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
-# Found deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
+# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
 # Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
 #
 ############################################################################
@@ -138,6 +138,10 @@ if __name__ == "__main__":
     if args.precision == "fp16":
         rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
 
+    os.environ["RWKV_JIT"] = "1"
+    if "deepspeed_stage_3" in args.strategy:
+        os.environ["RWKV_JIT"] = "0"
+
     import torch
     torch.backends.cudnn.benchmark = True
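The added lines disable TorchScript under ZeRO stage 3. The commit gives no rationale; a plausible reading is that stage 3 partitions parameters across ranks and re-gathers them through module hooks, which scripted modules do not compose with. Note the ordering constraint this creates, sketched below (the `src.model` module path is an assumption about the repo layout):

import os

def choose_rwkv_jit(strategy: str) -> None:
    # Mirror of the added logic; the rationale above is my reading, not the commit's.
    os.environ["RWKV_JIT"] = "0" if "deepspeed_stage_3" in strategy else "1"

choose_rwkv_jit("deepspeed_stage_2_offload")
# model.py reads RWKV_JIT at import time (first hunk above), so the model
# import must come after this point, e.g.:
# from src.model import RWKV   # hypothetical import for illustration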
@@ -260,7 +264,7 @@ if __name__ == "__main__":
         args.load_model = f"{args.proj_dir}/rwkv-init.pth"  # init weights to tmp file
         generate_init_weight(model, args.load_model)
 
-    print(f"\nLoading {args.load_model}...\n")
+    print(f"########## Loading {args.load_model}... ##########")
 
     load_dict = torch.load(args.load_model, map_location="cpu")
     model.load_state_dict(load_dict)
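Unchanged context, but worth noting: `map_location="cpu"` deserializes the checkpoint into host RAM, and `load_state_dict` then copies tensors one by one into the already-constructed parameters, so a large checkpoint never needs a second full copy in GPU memory. A minimal sketch of the pattern (tiny stand-in model and hypothetical file name):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)                      # stand-in for the real RWKV model
torch.save(model.state_dict(), "demo.pth")   # hypothetical checkpoint file

# Deserialize to CPU first, then copy into the model's parameters.
load_dict = torch.load("demo.pth", map_location="cpu")
model.load_state_dict(load_dict)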
