diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 342ac1b..c66f4cf 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -32,7 +32,7 @@ class MyDataset(Dataset):
             self.data_size = len(self.data._bin_buffer) // 2
             rank_zero_info(f"Data has {self.data_size} tokens.")
 
-            if args.my_qa_mask == 1:
+            if args.my_qa_mask > 0:
                 self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
                 self.data_pile_size = len(self.data_pile._bin_buffer) // 2
 
@@ -156,7 +156,7 @@ class MyDataset(Dataset):
         if args.my_pile_stage > 0:
             ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
 
-            if args.my_qa_mask == 1:
+            if args.my_qa_mask > 0:
                 ii_orig = ii
                 if ii % 2 == 0:
                     ii = (ii // 2) * args.magic_prime
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 5c4f786..7635b6c 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -108,12 +108,14 @@ class RWKV_TimeMix(MyModule):
         self.layer_id = layer_id
         self.ctx_len = args.ctx_len
         self.n_embd = args.n_embd
-        self.my_testing = self.args.my_testing
 
         with torch.no_grad():  # fancy init
             ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
-
+            ddd = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                ddd[0, 0, i] = i / args.n_embd
+
             # fancy time_decay
             decay_speed = torch.ones(args.dim_att)
             for h in range(args.dim_att):
@@ -126,12 +128,9 @@ class RWKV_TimeMix(MyModule):
             self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
 
             # fancy time_mix
-            x = torch.ones(1, 1, args.n_embd)
-            for i in range(args.n_embd):
-                x[0, 0, i] = i / args.n_embd
-            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
-            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
+            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
@@ -147,24 +146,17 @@ class RWKV_TimeMix(MyModule):
             self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
             self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
             with torch.no_grad():
-                x = torch.ones(1, 1, args.n_embd)
-                for i in range(args.n_embd):
-                    x[0, 0, i] = i / args.n_embd
-                self.time_mix_qq = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-                self.time_mix_kk = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-                self.time_mix_vv = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+                self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+                self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+                self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
 
     if 'a' not in os.environ["RWKV_MY_TESTING"]:
         @MyFunction
         def jit_func(self, x):
-
-            # Mix x with the previous timestep to produce xk, xv, xr
-            xx = self.time_shift(x)
+            xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
             xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
             xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
             xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
-
-            # Use xk, xv, xr to produce k, v, r
             k = self.key(xk)
             v = self.value(xv)
             r = self.receptance(xr)
@@ -188,25 +180,20 @@ class RWKV_TimeMix(MyModule):
 
         @MyFunction
         def jit_funcQKV(self, x):
-            # Mix x with the previous timestep to produce xk, xv, xr
-            xx = self.time_shift(x)
+            xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
             xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
             xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
             xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
             xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
             xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
             xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
-
-            # Use xk, xv, xr to produce k, v, r
             k = self.key(xk)
             v = self.value(xv)
             r = self.receptance(xr)
             sr = torch.sigmoid(r)
-
             qq = self.qq(xqq)
             kk = self.kk(xkk)
             vv = self.vv(xvv)
-
             return sr, k, v, qq, kk, vv
 
     def forward(self, x):
@@ -223,19 +210,16 @@ class RWKV_ChannelMix(MyModule):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
-        self.my_testing = self.args.my_testing
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
         with torch.no_grad():  # fancy init of time_mix
             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
-
-            x = torch.ones(1, 1, args.n_embd)
+            ddd = torch.ones(1, 1, args.n_embd)
             for i in range(args.n_embd):
-                x[0, 0, i] = i / args.n_embd
-
-            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-
+                ddd[0, 0, i] = i / args.n_embd
+            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+
         self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
         self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@@ -255,7 +239,6 @@ class MishGLU(MyModule):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
-        self.my_testing = self.args.my_testing
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
         with torch.no_grad():
@@ -478,7 +461,7 @@ class RWKV(pl.LightningModule):
 
     def training_step(self, batch, batch_idx):
         args = self.args
-        if args.my_qa_mask == 0:
+        if args.my_qa_mask != 1:
             idx, targets = batch
             logits = self(idx)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index 4b61621..89407f0 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -154,14 +154,32 @@ def generate_init_weight(model, init_weight_name):
     mm = model.generate_init_weight()
 
     if model.args.my_pile_stage == 1:
-        try:
+        if len(model.args.load_model) > 0:
             print(f"Combine weights from {model.args.load_model}...")
             load_dict = torch.load(model.args.load_model, map_location="cpu")
             for k in load_dict:
                 assert k in mm
-                mm[k] = load_dict[k].reshape(mm[k].shape)
-        except:
-            print(f"\n\n!!! FAIL !!!\n\n")
+                src = load_dict[k]
+                try:
+                    mm[k] = src.reshape(mm[k].shape)
+                except:
+                    tmp = mm[k].squeeze().clone()
+                    print(k, src.shape, '-->', mm[k].shape)
+                    ss = src.shape[0]
+                    dd = tmp.shape[0]
+                    for i in range(dd):
+                        pos = i / dd * ss
+                        if pos >= ss - 1:
+                            tmp[i] = src[ss-1]
+                        else:
+                            p0 = int(math.floor(pos))
+                            ii = pos - p0
+                            tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
+                    mm[k] = tmp.reshape(mm[k].shape)
+                    sss = src.squeeze().float().cpu().numpy()
+                    print(sss[:10], '...', sss[-10:])
+                    mmm = mm[k].squeeze().float().cpu().numpy()
+                    print(mmm[:10], '...', mmm[-10:])
 
     print(f"Save to {init_weight_name}...")
     torch.save(mm, init_weight_name)
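
Note on the model.py hunks: the per-channel ramp ddd[0, 0, i] = i / n_embd used to be rebuilt as a local `x` in both the k/v/r and qq/kk/vv branches; the patch hoists one shared `ddd`. A minimal standalone sketch of what this init computes, vectorized with torch.arange instead of the loop (the function name is illustrative, not part of the patch):

    import torch

    def init_time_mix(n_embd: int, n_layer: int, layer_id: int):
        # Ramp 0/n_embd .. (n_embd-1)/n_embd over the channel dim, shape (1, 1, n_embd);
        # equivalent to the for-loop that fills ddd in the patch.
        ddd = (torch.arange(n_embd, dtype=torch.float32) / n_embd).view(1, 1, n_embd)
        ratio_0_to_1 = layer_id / (n_layer - 1)          # 0 at the first layer, 1 at the last
        ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 at the first layer, ~0 at the last
        time_mix_k = torch.pow(ddd, ratio_1_to_almost0)
        time_mix_v = torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
        time_mix_r = torch.pow(ddd, 0.5 * ratio_1_to_almost0)
        return time_mix_k, time_mix_v, time_mix_r

Because the exponent shrinks with depth, pow(ddd, exponent) moves toward 1 in deeper layers, so they weight the current token more heavily relative to the shifted one.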
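
The jit_func hunks fold the two explanatory comments into the time_shift line. For context, nn.ZeroPad2d((0, 0, 1, -1)) applied to a (B, T, C) tensor pads one zero row at the front of the time axis and crops the last row, i.e. a one-token right shift. A small sketch with illustrative values:

    import torch
    import torch.nn as nn

    time_shift = nn.ZeroPad2d((0, 0, 1, -1))     # (left, right, top, bottom) over (T, C)
    x = torch.arange(6.0).view(1, 3, 2)          # (B=1, T=3, C=2): [[0,1],[2,3],[4,5]]
    xx = time_shift(x)                           # previous token per row: [[0,0],[0,1],[2,3]]
    time_mix_k = torch.full((1, 1, 2), 0.5)      # stand-in for the learned parameter
    xk = x * time_mix_k + xx * (1 - time_mix_k)  # per-channel blend of current and previous token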
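
The trainer.py hunk replaces the all-or-nothing try/except around weight combining: each tensor is reshaped when possible, and otherwise resampled by linear interpolation, so a checkpoint with a different width can still seed the init (e.g. stretching a 1-D time_decay vector). A standalone sketch of that fallback under a hypothetical helper name; the patch does this in place on mm[k]:

    import math
    import torch

    def resample_1d(src: torch.Tensor, new_len: int) -> torch.Tensor:
        # Stretch or shrink a 1-D tensor to new_len by linear interpolation,
        # mirroring the fallback loop in generate_init_weight.
        ss = src.shape[0]
        out = torch.empty(new_len, dtype=src.dtype)
        for i in range(new_len):
            pos = i / new_len * ss             # fractional position in the source
            if pos >= ss - 1:
                out[i] = src[ss - 1]           # clamp at the right edge
            else:
                p0 = int(math.floor(pos))
                frac = pos - p0
                out[i] = src[p0] * (1 - frac) + src[p0 + 1] * frac
        return out

    # resample_1d(torch.tensor([0.0, 1.0, 2.0, 3.0]), 8)
    # -> tensor([0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.0])

The prints of the first and last ten entries of source and result in the patch serve as a quick sanity check that the resampled tensor preserves the original's overall shape.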