@@ -57,7 +57,7 @@ class WKV(torch.autograd.Function):
             k = k.float().contiguous()
             v = v.float().contiguous()
         ctx.save_for_backward(w, u, k, v)
-        y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
+        y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
         wkv_cuda.forward(B, T, C, w, u, k, v, y)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             return y
@@ -74,10 +74,10 @@ class WKV(torch.autograd.Function):
         assert T <= T_MAX
         assert B * C % min(C, 32) == 0
         w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device="cuda").contiguous()
-        gu = torch.zeros((B, C), device="cuda").contiguous()
-        gk = torch.zeros((B, T, C), device="cuda").contiguous()
-        gv = torch.zeros((B, T, C), device="cuda").contiguous()
+        gw = torch.zeros((B, C), device=gy.device).contiguous()
+        gu = torch.zeros((B, C), device=gy.device).contiguous()
+        gk = torch.zeros((B, T, C), device=gy.device).contiguous()
+        gv = torch.zeros((B, T, C), device=gy.device).contiguous()
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
         else:
@@ -93,7 +93,7 @@ class WKV(torch.autograd.Function):


 def RUN_CUDA(B, T, C, w, u, k, v):
-    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+    return WKV.apply(B, T, C, w, u, k, v)


 ########################################################################################################
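
Commentary (not part of the patch): the three hunks above swap the hard-coded device="cuda" allocations and .cuda() copies for the device of the incoming tensors (w.device in forward, gy.device in backward), so the WKV op stays on whatever GPU the caller placed the model. A minimal usage sketch of what this enables, assuming the wkv_cuda extension and T_MAX from this file, RWKV_FLOAT_MODE="fp32", T <= T_MAX, and a machine with a second GPU:

import torch

# Hedged sketch: with device=w.device / device=gy.device, the output and gradient
# buffers are allocated on the same GPU as the inputs, so the op also works on a
# non-default device such as cuda:1 (e.g. one rank of a multi-GPU training run).
dev = torch.device("cuda:1")           # hypothetical second GPU
B, T, C = 2, 128, 512
with torch.cuda.device(dev):           # make kernel launches target the same GPU
    w = torch.randn(C, device=dev)
    u = torch.randn(C, device=dev)
    k = torch.randn(B, T, C, device=dev, requires_grad=True)
    v = torch.randn(B, T, C, device=dev, requires_grad=True)
    y = RUN_CUDA(B, T, C, w, u, k, v)  # no implicit .cuda() copies to cuda:0 anymore
    assert y.device == dev
    y.sum().backward()                 # gk/gv come back on cuda:1 as well
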
@@ -262,6 +262,9 @@ class RWKV(pl.LightningModule):
         self.args = args

         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+        if args.my_pos_emb > 0:
+            self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
+            self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))

         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

@@ -332,31 +335,35 @@ class RWKV(pl.LightningModule):
         return False

     def forward(self, idx):
+        args = self.args
         B, T = idx.size()
-        assert T <= self.args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

         x = self.emb(idx)
+        if args.my_pos_emb > 0:
+            pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+            x = x + pos_emb

         for block in self.blocks:
-            if self.args.grad_cp == 1:
+            if args.grad_cp == 1:
                 x = deepspeed.checkpointing.checkpoint(block, x)
             else:
                 x = block(x)

         x = self.ln_out(x)

-        if self.args.head_qk > 0:
+        if args.head_qk > 0:
             q = self.head_q(x)[:, :T, :]
             k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / self.args.head_qk)
+            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

             if "32" in os.environ["RWKV_FLOAT_MODE"]:
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size)
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
             elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).half()
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).bfloat16()
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

             x = self.head(x) + c
         else:
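
Commentary (not part of the patch): pos_emb_x of shape (1, P, C) and pos_emb_y of shape (P, 1, C) broadcast into a (P, P, C) grid, which forward flattens with reshape(T+1, -1)[:-1, :]; that only lines up when T + 1 == P * P (my reading, with P = args.my_pos_emb and C = args.n_embd). A small shape sketch:

import torch

# Hedged shape sketch of the 2-D positional embedding added above.
P, C = 16, 64
T = P * P - 1                        # assumed context length so reshape(T+1, -1) yields (P*P, C)

pos_emb_x = torch.zeros((1, P, C))   # one vector per x position
pos_emb_y = torch.zeros((P, 1, C))   # one vector per y position
grid = pos_emb_x + pos_emb_y         # broadcasts to (P, P, C): one vector per grid cell

pos_emb = grid.reshape(T + 1, -1)[:-1, :]   # flatten to (P*P, C), drop the last row
assert pos_emb.shape == (T, C)              # one positional vector per token

x = torch.randn(2, T, C)             # (B, T, C) token embeddings
x = x + pos_emb                      # broadcast add over the batch dimension
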
@@ -393,7 +400,7 @@ class RWKV(pl.LightningModule):

             gain = 1.0
             scale = 1.0
-            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n:
+            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
                 m[n] = p
             else:
                 if n == "emb.weight":
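
Commentary (not part of the patch): extending the filter in generate_init_weight with "pos_emb" in n keeps the two new parameters at their torch.zeros initialization instead of giving them the scaled init applied to other weights, because their state-dict names contain that substring:

# Hedged sketch: the new parameters are registered directly on the RWKV module,
# so their state-dict names are "pos_emb_x" and "pos_emb_y"; the substring test
# matches both, and m[n] = p copies the zero tensors through unchanged.
for n in ("pos_emb_x", "pos_emb_y"):
    assert "pos_emb" in n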