main
BlinkDL 3 years ago
parent 605637ca6f
commit aef9f6f7ef

@ -22,7 +22,7 @@ args = types.SimpleNamespace()
######################################################################################################## ########################################################################################################
args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda' args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp32" # fp32 // bf16 (saves VRAM, slightly less accurate) args.FLOAT_MODE = "fp32" # fp32 (good for cpu) // fp16 (might overflow) // bf16 (less accurate)
# if args.RUN_DEVICE == "cuda": # if args.RUN_DEVICE == "cuda":
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
@ -34,7 +34,9 @@ WORD_NAME = [
UNKNOWN_CHAR = None UNKNOWN_CHAR = None
vocab_size = 50277 vocab_size = 50277
# note; you can set MODEL_NAME to your fine-tuned model # Download Pile models: https://huggingface.co/BlinkDL
# or, set MODEL_NAME to your fine-tuned model
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023" # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12 # n_layer = 12
# n_embd = 768 # n_embd = 768
@ -50,21 +52,16 @@ vocab_size = 50277
# n_embd = 2048 # n_embd = 2048
# ctx_len = 1024 # ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096' # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# n_layer = 24 # n_layer = 32
# n_embd = 2048 # n_embd = 2560
# ctx_len = 4096 # ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783' MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
n_layer = 32 n_layer = 32
n_embd = 2560 n_embd = 4096
ctx_len = 1024 ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
# n_layer = 32
# n_embd = 4096
# ctx_len = 1024
args.MODEL_NAME = MODEL_NAME args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer args.n_layer = n_layer
args.n_embd = n_embd args.n_embd = n_embd

@ -55,6 +55,8 @@ class RWKV_RNN(nn.Module):
w[x] = w[x].float() w[x] = w[x].float()
elif self.FLOAT_MODE == "bf16": elif self.FLOAT_MODE == "bf16":
w[x] = w[x].bfloat16() w[x] = w[x].bfloat16()
elif self.FLOAT_MODE == "fp16":
w[x] = w[x].half()
w[x].requires_grad = False w[x].requires_grad = False
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight': if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
@ -106,6 +108,10 @@ class RWKV_RNN(nn.Module):
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+0] = x.float() state[5*i+0] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
state[5*i+0] = x.float()
else: else:
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k) xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r) xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
@ -124,6 +130,11 @@ class RWKV_RNN(nn.Module):
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r) xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+1] = x.float() state[5*i+1] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
state[5*i+1] = x.float()
else: else:
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
@ -134,7 +145,7 @@ class RWKV_RNN(nn.Module):
k = kw @ xk k = kw @ xk
v = vw @ xv v = vw @ xv
if self.FLOAT_MODE == "bf16": if '16' in self.FLOAT_MODE:
kk = k.float() kk = k.float()
vv = v.float() vv = v.float()
else: else:
@ -158,6 +169,8 @@ class RWKV_RNN(nn.Module):
state[5*i+4] = p state[5*i+4] = p
if self.FLOAT_MODE == "bf16": if self.FLOAT_MODE == "bf16":
wkv = (a / b).type(torch.bfloat16) wkv = (a / b).type(torch.bfloat16)
elif self.FLOAT_MODE == "fp16":
wkv = (a / b).half()
else: else:
wkv = a / b wkv = a / b

Loading…
Cancel
Save