diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py index 2f28e63..61eb3f8 100644 --- a/RWKV-v4neo/run.py +++ b/RWKV-v4neo/run.py @@ -22,7 +22,7 @@ args = types.SimpleNamespace() ######################################################################################################## args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda' -args.FLOAT_MODE = "fp32" # fp32 // bf16 (saves VRAM, slightly less accurate) +args.FLOAT_MODE = "fp32" # fp32 (good for cpu) // fp16 (might overflow) // bf16 (less accurate) # if args.RUN_DEVICE == "cuda": # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output @@ -34,7 +34,9 @@ WORD_NAME = [ UNKNOWN_CHAR = None vocab_size = 50277 -# note; you can set MODEL_NAME to your fine-tuned model +# Download Pile models: https://huggingface.co/BlinkDL +# or, set MODEL_NAME to your fine-tuned model + # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023" # n_layer = 12 # n_embd = 768 @@ -50,21 +52,16 @@ vocab_size = 50277 # n_embd = 2048 # ctx_len = 1024 -# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096' -# n_layer = 24 -# n_embd = 2048 -# ctx_len = 4096 +# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023' +# n_layer = 32 +# n_embd = 2560 +# ctx_len = 1024 -MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783' +MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' n_layer = 32 -n_embd = 2560 +n_embd = 4096 ctx_len = 1024 -# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047' -# n_layer = 32 -# n_embd = 4096 -# ctx_len = 1024 - args.MODEL_NAME = MODEL_NAME args.n_layer = n_layer args.n_embd = n_embd diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py index f3325eb..c12fee4 100644 --- a/RWKV-v4neo/src/model_run.py +++ b/RWKV-v4neo/src/model_run.py @@ -55,6 +55,8 @@ class RWKV_RNN(nn.Module): w[x] = w[x].float() elif self.FLOAT_MODE == "bf16": w[x] = w[x].bfloat16() + elif self.FLOAT_MODE == "fp16": + w[x] = w[x].half() w[x].requires_grad = False if args.RUN_DEVICE == 'cuda' and x != 'emb.weight': @@ -106,6 +108,10 @@ class RWKV_RNN(nn.Module): xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) state[5*i+0] = x.float() + elif self.FLOAT_MODE == "fp16": + xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k) + xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r) + state[5*i+0] = x.float() else: xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k) xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r) @@ -124,6 +130,11 @@ class RWKV_RNN(nn.Module): xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r) state[5*i+1] = x.float() + elif self.FLOAT_MODE == "fp16": + xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k) + xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v) + xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r) + state[5*i+1] = x.float() else: xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) @@ -134,7 +145,7 @@ class RWKV_RNN(nn.Module): k = kw @ xk v = vw @ xv - if self.FLOAT_MODE == "bf16": + if '16' in self.FLOAT_MODE: kk = k.float() vv = v.float() else: @@ -158,6 +169,8 @@ class RWKV_RNN(nn.Module): state[5*i+4] = p if self.FLOAT_MODE == "bf16": wkv = (a / b).type(torch.bfloat16) + elif self.FLOAT_MODE == "fp16": + wkv = (a / b).half() else: wkv = a / b