commit aef9f6f7ef (main)
parent 605637ca6f
BlinkDL · 3 years ago

@@ -22,7 +22,7 @@ args = types.SimpleNamespace()
 ########################################################################################################
 args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda'
-args.FLOAT_MODE = "fp32" # fp32 // bf16 (saves VRAM, slightly less accurate)
+args.FLOAT_MODE = "fp32" # fp32 (good for cpu) // fp16 (might overflow) // bf16 (less accurate)
 # if args.RUN_DEVICE == "cuda":
 #     os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
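A note on the new "fp16 (might overflow)" hedge (this aside and its snippet are editorial, not part of the commit): fp16's largest finite value is 65504, so exponential sums like the WKV accumulator can saturate to inf unless carried in fp32, while bf16 trades mantissa bits for fp32's full exponent range. A minimal illustration:

import torch
x = torch.tensor(70000.0)
print(x.half())      # inf -- 70000 exceeds fp16's max finite value (65504)
print(x.bfloat16())  # ~7e4 -- bf16 keeps fp32's range, at reduced precision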
@@ -34,7 +34,9 @@ WORD_NAME = [
 UNKNOWN_CHAR = None
 vocab_size = 50277
-# note; you can set MODEL_NAME to your fine-tuned model
+# Download Pile models: https://huggingface.co/BlinkDL
+# or, set MODEL_NAME to your fine-tuned model
 # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
 # n_layer = 12
 # n_embd = 768
@@ -50,21 +52,16 @@ vocab_size = 50277
 # n_embd = 2048
 # ctx_len = 1024
 # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220929-ctx4096'
 # n_layer = 24
 # n_embd = 2048
 # ctx_len = 4096
 # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
 # n_layer = 32
 # n_embd = 2560
 # ctx_len = 1024
-MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
+MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
 n_layer = 32
-n_embd = 2560
+n_embd = 4096
 ctx_len = 1024
-# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221004-3047'
-# n_layer = 32
-# n_embd = 4096
-# ctx_len = 1024
 args.MODEL_NAME = MODEL_NAME
 args.n_layer = n_layer
 args.n_embd = n_embd
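The uncommented n_layer / n_embd must match the checkpoint being loaded; a wrong n_embd fails later at matmul time rather than at load time. Both can be recovered from the checkpoint itself — a minimal editorial sketch, assuming the usual RWKV key layout (emb.weight, blocks.<i>. prefixes):

import torch
w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
n_embd = w['emb.weight'].shape[1]   # emb.weight is [vocab_size, n_embd]
n_layer = 1 + max(int(k.split('.')[1]) for k in w if k.startswith('blocks.'))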

@@ -55,6 +55,8 @@ class RWKV_RNN(nn.Module):
                     w[x] = w[x].float()
                 elif self.FLOAT_MODE == "bf16":
                     w[x] = w[x].bfloat16()
+                elif self.FLOAT_MODE == "fp16":
+                    w[x] = w[x].half()
             w[x].requires_grad = False
             if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
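The growing if/elif cast chain is equivalent to a dtype lookup — an editorial sketch (DTYPES is a hypothetical name; .float() / .bfloat16() / .half() are shorthands for .to(...)):

import torch
DTYPES = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}
t = torch.randn(3)
t = t.to(DTYPES["fp16"])   # same result as t.half()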
@@ -106,6 +108,10 @@ class RWKV_RNN(nn.Module):
             xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
             xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
             state[5*i+0] = x.float()
+        elif self.FLOAT_MODE == "fp16":
+            xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
+            xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
+            state[5*i+0] = x.float()
         else:
             xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
             xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
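Note that in both half-precision branches the token-shift state state[5*i+0] (the previous step's x) is written back as fp32 via x.float() and only cast down at the point of use, so rounding error does not compound in the recurrent state. Schematically the mix is a per-channel lerp — a runnable editorial sketch with hypothetical names:

import torch
n_embd = 4
x_prev = torch.zeros(n_embd)                       # fp32 shift state, like state[5*i+0]
x = torch.randn(n_embd, dtype=torch.half)          # current token's input, compute dtype
time_mix_k = torch.rand(n_embd, dtype=torch.half)  # learned per-channel mix factor
xk = x * time_mix_k + x_prev.half() * (1 - time_mix_k)  # blend current and previous token
x_prev = x.float()                                 # state stays fp32 between steps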
@@ -124,6 +130,11 @@ class RWKV_RNN(nn.Module):
             xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
             xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
             state[5*i+1] = x.float()
+        elif self.FLOAT_MODE == "fp16":
+            xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
+            xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
+            xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
+            state[5*i+1] = x.float()
         else:
             xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
             xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
@@ -134,7 +145,7 @@ class RWKV_RNN(nn.Module):
         k = kw @ xk
         v = vw @ xv
-        if self.FLOAT_MODE == "bf16":
+        if '16' in self.FLOAT_MODE:
             kk = k.float()
             vv = v.float()
         else:
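Since '16' is a substring of both "fp16" and "bf16", this single condition upcasts k and v to fp32 for either half-precision mode before they enter the exponential accumulator below; a quick editorial check:

for m in ("fp32", "bf16", "fp16"):
    print(m, '16' in m)   # fp32 False, bf16 True, fp16 True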
@@ -158,6 +169,8 @@ class RWKV_RNN(nn.Module):
         state[5*i+4] = p
         if self.FLOAT_MODE == "bf16":
             wkv = (a / b).type(torch.bfloat16)
+        elif self.FLOAT_MODE == "fp16":
+            wkv = (a / b).half()
         else:
             wkv = a / b
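Casting a / b down only at the end is safe because the WKV recurrence is carried in fp32 with a running maximum exponent: each step rescales the numerator and denominator by exp(old_max - new_max) before adding the new term, so every exponential stays <= 1. A condensed editorial sketch of that update (names approximate the surrounding file):

import torch

def wkv_step(a, b, p, u, k, v):
    # a, b: fp32 numerator/denominator, stored pre-scaled by exp(-p)
    # p: running max exponent; u: per-channel "time_first" bonus for the current token
    p_new = torch.maximum(p, u + k)
    e1 = torch.exp(p - p_new)        # rescale old accumulators to the new max
    e2 = torch.exp(u + k - p_new)    # incoming contribution, also <= 1
    return e1 * a + e2 * v, e1 * b + e2, p_new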
