diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 05d97bb..1e0ae1b 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -57,7 +57,7 @@ class WKV(torch.autograd.Function):
             k = k.float().contiguous()
             v = v.float().contiguous()
         ctx.save_for_backward(w, u, k, v)
-        y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
+        y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
         wkv_cuda.forward(B, T, C, w, u, k, v, y)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             return y
@@ -74,10 +74,10 @@ class WKV(torch.autograd.Function):
         assert T <= T_MAX
         assert B * C % min(C, 32) == 0
         w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device="cuda").contiguous()
-        gu = torch.zeros((B, C), device="cuda").contiguous()
-        gk = torch.zeros((B, T, C), device="cuda").contiguous()
-        gv = torch.zeros((B, T, C), device="cuda").contiguous()
+        gw = torch.zeros((B, C), device=gy.device).contiguous()
+        gu = torch.zeros((B, C), device=gy.device).contiguous()
+        gk = torch.zeros((B, T, C), device=gy.device).contiguous()
+        gv = torch.zeros((B, T, C), device=gy.device).contiguous()
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
         else:
@@ -93,7 +93,7 @@ class WKV(torch.autograd.Function):
 
 
 def RUN_CUDA(B, T, C, w, u, k, v):
-    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+    return WKV.apply(B, T, C, w, u, k, v)
 
 
 ########################################################################################################
@@ -262,6 +262,9 @@ class RWKV(pl.LightningModule):
         self.args = args
 
         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+        if args.my_pos_emb > 0:
+            self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
+            self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
 
         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
 
@@ -332,31 +335,35 @@ class RWKV(pl.LightningModule):
         return False
 
     def forward(self, idx):
+        args = self.args
         B, T = idx.size()
-        assert T <= self.args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
 
         x = self.emb(idx)
+        if args.my_pos_emb > 0:
+            pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+            x = x + pos_emb
 
         for block in self.blocks:
-            if self.args.grad_cp == 1:
+            if args.grad_cp == 1:
                 x = deepspeed.checkpointing.checkpoint(block, x)
             else:
                 x = block(x)
 
         x = self.ln_out(x)
 
-        if self.args.head_qk > 0:
+        if args.head_qk > 0:
             q = self.head_q(x)[:, :T, :]
             k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / self.args.head_qk)
+            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
 
             if "32" in os.environ["RWKV_FLOAT_MODE"]:
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size)
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
             elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).half()
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).bfloat16()
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
 
             x = self.head(x) + c
         else:
@@ -393,7 +400,7 @@ class RWKV(pl.LightningModule):
 
             gain = 1.0
             scale = 1.0
-            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n:
+            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
                 m[n] = p
             else:
                 if n == "emb.weight":
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index 8a648eb..c29979b 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -135,11 +135,14 @@ def generate_init_weight(model, init_weight_name):
     mm = model.generate_init_weight()
 
     if model.args.my_pile_stage == 1:
-        print(f"Combine weights from {model.args.load_model}...")
-        load_dict = torch.load(model.args.load_model, map_location="cpu")
-        for k in load_dict:
-            assert k in mm
-            mm[k] = load_dict[k].reshape(mm[k].shape)
+        try:
+            print(f"Combine weights from {model.args.load_model}...")
+            load_dict = torch.load(model.args.load_model, map_location="cpu")
+            for k in load_dict:
+                assert k in mm
+                mm[k] = load_dict[k].reshape(mm[k].shape)
+        except:
+            print(f"\n\n!!! FAIL !!!\n\n")
 
     print(f"Save to {init_weight_name}...")
     torch.save(mm, init_weight_name)
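
A note on the `device=` hunks: allocating the output and gradient buffers from tensors the kernel already holds (`w.device`, `gy.device`) instead of a hard-coded `"cuda"` makes the autograd Function follow its inputs, which multi-GPU training needs when each rank's tensors live on a different device. It also pairs with the `RUN_CUDA` change, which stops force-moving `w`, `u`, `k`, `v` and leaves placement to the caller. A minimal sketch of the pattern, using a hypothetical doubling op in place of the real WKV kernel:

```python
import torch

class DeviceFollowingOp(torch.autograd.Function):
    """Hypothetical stand-in for WKV: y = 2 * x."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Buffer follows x.device, so cuda:0, cuda:1, or cpu all work;
        # a hard-coded device="cuda" would break ranks pinned elsewhere.
        y = torch.empty(x.shape, device=x.device, memory_format=torch.contiguous_format)
        torch.mul(x, 2.0, out=y)
        return y

    @staticmethod
    def backward(ctx, gy):
        (x,) = ctx.saved_tensors
        # Gradient buffer follows the incoming gradient's device,
        # mirroring the gw/gu/gk/gv allocations in the patch.
        gx = torch.zeros(x.shape, device=gy.device)
        gx.add_(2.0 * gy)
        return gx
```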
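
The `my_pos_emb` parameters factorize a 2D positional embedding: `pos_emb_x` is `(1, P, C)` and `pos_emb_y` is `(P, 1, C)`, so their sum broadcasts to a full `(P, P, C)` grid from only `2*P*C` parameters, and `reshape(T+1, -1)[:-1, :]` flattens that grid and keeps the first `T` rows (which only lines up when `T = P*P - 1`). A shape-only sketch with illustrative values:

```python
import torch

P, C = 4, 8        # illustrative grid size and embedding width
T = P * P - 1      # the sequence length the reshape trick expects

pos_emb_x = torch.zeros((1, P, C))  # shared across rows
pos_emb_y = torch.zeros((P, 1, C))  # shared across columns

# Broadcasting builds the full positional grid from the two factors.
grid = pos_emb_x + pos_emb_y
assert grid.shape == (P, P, C)

# Flatten and trim, exactly as forward() does.
pos_emb = grid.reshape(T + 1, -1)[:-1, :]
assert pos_emb.shape == (T, C)
```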
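
The `grad_cp` branch (unchanged in behavior here, only re-aliased to `args`) trades compute for memory: `deepspeed.checkpointing.checkpoint(block, x)` discards each block's activations on the forward pass and recomputes them during backward. The same idea with stock PyTorch, as a sketch (the `nn.Linear` blocks are placeholders for the real `Block` modules):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 8)

grad_cp = 1  # mirrors args.grad_cp
for block in blocks:
    if grad_cp == 1:
        # Activations inside `block` are recomputed on backward,
        # cutting peak memory at the cost of extra forward FLOPs.
        x = checkpoint(block, x, use_reentrant=False)
    else:
        x = block(x)
x.sum().backward()
```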
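
For context on the `head_qk` block: multiplying the masked q·k scores by a one-hot of the input ids turns attention weight on an earlier position `s` into a logit bonus for that position's token `idx[s]`, i.e. a learned copy mechanism, which is why the diff casts the one-hot to match `c`'s dtype under fp16/bf16. A shape-only sketch with illustrative sizes:

```python
import torch
import torch.nn.functional as F

B, T, D, vocab, head_qk = 2, 5, 8, 11, 256  # illustrative sizes

q = torch.randn(B, T, D)
k = torch.randn(B, T, D)
idx = torch.randint(0, vocab, (B, T))

# Scaled scores under a causal copy mask (no future positions).
c = (q @ k.transpose(-2, -1)) * (1.0 / head_qk)
c = c.masked_fill(torch.tril(torch.ones(T, T)) == 0, 0)

# (B, T, T) @ (B, T, vocab) -> (B, T, vocab): position t's score
# against position s lands on the logit of token idx[s].
copy_logits = c @ F.one_hot(idx, num_classes=vocab).float()
assert copy_logits.shape == (B, T, vocab)
```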