main
BlinkDL 3 years ago
parent 40237495cc
commit 1479315677

@@ -57,7 +57,7 @@ class WKV(torch.autograd.Function):
         k = k.float().contiguous()
         v = v.float().contiguous()
         ctx.save_for_backward(w, u, k, v)
-        y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
+        y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
         wkv_cuda.forward(B, T, C, w, u, k, v, y)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             return y
@@ -74,10 +74,10 @@ class WKV(torch.autograd.Function):
         assert T <= T_MAX
         assert B * C % min(C, 32) == 0
         w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device="cuda").contiguous()
-        gu = torch.zeros((B, C), device="cuda").contiguous()
-        gk = torch.zeros((B, T, C), device="cuda").contiguous()
-        gv = torch.zeros((B, T, C), device="cuda").contiguous()
+        gw = torch.zeros((B, C), device=gy.device).contiguous()
+        gu = torch.zeros((B, C), device=gy.device).contiguous()
+        gk = torch.zeros((B, T, C), device=gy.device).contiguous()
+        gv = torch.zeros((B, T, C), device=gy.device).contiguous()
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
         else:
@@ -93,7 +93,7 @@ class WKV(torch.autograd.Function):
 def RUN_CUDA(B, T, C, w, u, k, v):
-    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+    return WKV.apply(B, T, C, w, u, k, v)

 ########################################################################################################
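
The three hunks above make the kernel wrapper device-agnostic: output and gradient buffers are allocated on the device of tensors the op already holds (w.device, gy.device) instead of a hard-coded "cuda", and RUN_CUDA no longer forces .cuda() on its arguments, so everything follows wherever the caller placed the model. Below is a minimal sketch of the same pattern with a hypothetical toy op (ScaleOp is illustrative, not the real WKV kernel):

import torch

class ScaleOp(torch.autograd.Function):
    # hypothetical toy op standing in for the WKV kernel above
    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        # allocate the output on x.device, not a hard-coded "cuda"
        y = torch.empty_like(x, memory_format=torch.contiguous_format)
        torch.mul(x, scale, out=y)
        return y

    @staticmethod
    def backward(ctx, gy):
        (x,) = ctx.saved_tensors
        # gradient buffers live wherever the incoming gradient lives
        gx = torch.zeros(x.shape, device=gy.device)
        gx += gy * ctx.scale
        return gx, None

x = torch.randn(4, 8, requires_grad=True)  # CPU tensor; no .cuda() needed
y = ScaleOp.apply(x, 2.0)
y.sum().backward()                         # gx lands on gy.device (CPU here)

The same code then runs unchanged on CPU, a single GPU, or a non-default device such as cuda:1.
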
@@ -262,6 +262,9 @@ class RWKV(pl.LightningModule):
         self.args = args
         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+        if args.my_pos_emb > 0:
+            self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
+            self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
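
The two new parameters are shaped to broadcast against each other: a (1, P, E) embedding shared across rows plus a (P, 1, E) embedding shared across columns sums to a full (P, P, E) grid while training only 2·P·E numbers. A standalone shape check of the trick, with placeholder sizes P and E:

import torch

P, E = 16, 64                        # placeholder grid size and embed dim
pos_emb_x = torch.zeros((1, P, E))   # shared across rows
pos_emb_y = torch.zeros((P, 1, E))   # shared across columns

grid = pos_emb_x + pos_emb_y         # broadcasts to (P, P, E)
assert grid.shape == (P, P, E)

# the forward pass below recovers a (T, E) table via
# grid.reshape(T + 1, -1)[:-1, :], which works out when T == P*P - 1
flat = grid.reshape(P * P, E)
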
@@ -332,31 +335,35 @@ class RWKV(pl.LightningModule):
         return False

     def forward(self, idx):
+        args = self.args
         B, T = idx.size()
-        assert T <= self.args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

         x = self.emb(idx)
+        if args.my_pos_emb > 0:
+            pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+            x = x + pos_emb

         for block in self.blocks:
-            if self.args.grad_cp == 1:
+            if args.grad_cp == 1:
                 x = deepspeed.checkpointing.checkpoint(block, x)
             else:
                 x = block(x)

         x = self.ln_out(x)

-        if self.args.head_qk > 0:
+        if args.head_qk > 0:
             q = self.head_q(x)[:, :T, :]
             k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / self.args.head_qk)
+            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

             if "32" in os.environ["RWKV_FLOAT_MODE"]:
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size)
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
             elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).half()
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).bfloat16()
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

             x = self.head(x) + c
         else:
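
In the head_qk branch, c is a (B, T, T) matrix of causal weights after masked_fill, and F.one_hot(idx, ...) is (B, T, V); the batched product scatters each weight onto the vocabulary slot of the attended token, giving this extra head a direct way to copy tokens out of the context. A shape-level sketch with toy sizes (not the trained model):

import torch
import torch.nn.functional as F

B, T, V = 2, 5, 10                        # toy batch, length, vocab
idx = torch.randint(0, V, (B, T))
c = torch.rand(B, T, T).tril()            # causal weights, as after masked_fill

onehot = F.one_hot(idx, num_classes=V).float()  # (B, T, V)
copy_logits = c @ onehot                        # (B, T, V)

# position t boosts the logit of token idx[b, s] by c[b, t, s],
# summed over every attended position s <= t
assert copy_logits.shape == (B, T, V)
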
@ -393,7 +400,7 @@ class RWKV(pl.LightningModule):
gain = 1.0
scale = 1.0
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n:
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
m[n] = p
else:
if n == "emb.weight":
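
Adding "pos_emb" to the name filter means pos_emb_x / pos_emb_y keep their zero initialization rather than receiving the scaled init applied to other weights. A toy partition under the same substring test (the state-dict entries are illustrative):

import torch

state = {
    "emb.weight": torch.zeros(10, 4),
    "blocks.0.ln1.weight": torch.ones(4),
    "pos_emb_x": torch.zeros(1, 8, 4),
}

kept, scaled = [], []
for n in state:
    # same substring test as the diff: matches keep their current values
    if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
        kept.append(n)
    else:
        scaled.append(n)  # everything else gets the scaled init

assert kept == ["blocks.0.ln1.weight", "pos_emb_x"]
assert scaled == ["emb.weight"]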

@@ -135,11 +135,14 @@ def generate_init_weight(model, init_weight_name):
     mm = model.generate_init_weight()

     if model.args.my_pile_stage == 1:
-        print(f"Combine weights from {model.args.load_model}...")
-        load_dict = torch.load(model.args.load_model, map_location="cpu")
-        for k in load_dict:
-            assert k in mm
-            mm[k] = load_dict[k].reshape(mm[k].shape)
+        try:
+            print(f"Combine weights from {model.args.load_model}...")
+            load_dict = torch.load(model.args.load_model, map_location="cpu")
+            for k in load_dict:
+                assert k in mm
+                mm[k] = load_dict[k].reshape(mm[k].shape)
+        except:
+            print(f"\n\n!!! FAIL !!!\n\n")

     print(f"Save to {init_weight_name}...")
     torch.save(mm, init_weight_name)
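
The try/except keeps stage-1 initialization usable when the old checkpoint cannot be merged, at the cost of swallowing the reason for the failure. For comparison, a hypothetical variant of the same merge step (not part of this commit) that names the offending key:

import torch

def combine_weights(mm, load_path):
    # hypothetical helper mirroring the merge loop above, but raising
    # with the key that is missing or mis-shaped instead of printing FAIL
    load_dict = torch.load(load_path, map_location="cpu")
    for k, t in load_dict.items():
        if k not in mm:
            raise KeyError(f"{k} not present in freshly initialized weights")
        mm[k] = t.reshape(mm[k].shape)  # raises if element counts differ
    return mm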
