diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 9b23fbb..5c4f786 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -109,22 +109,21 @@ class RWKV_TimeMix(MyModule):
         self.ctx_len = args.ctx_len
         self.n_embd = args.n_embd
         self.my_testing = self.args.my_testing
-        attn_sz = args.n_embd

         with torch.no_grad():  # fancy init
             ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0

             # fancy time_decay
-            decay_speed = torch.ones(attn_sz)
-            for h in range(attn_sz):
-                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+            decay_speed = torch.ones(args.dim_att)
+            for h in range(args.dim_att):
+                decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
             self.time_decay = nn.Parameter(decay_speed)
             # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

             # fancy time_first
-            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
-            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
+            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
+            self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)

             # fancy time_mix
             x = torch.ones(1, 1, args.n_embd)
@@ -135,10 +134,10 @@ class RWKV_TimeMix(MyModule):
             self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-        self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
-        self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
-        self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
-        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
+        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)

         if 'a' in os.environ["RWKV_MY_TESTING"]:
             self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
@@ -175,7 +174,7 @@ class RWKV_TimeMix(MyModule):
         def forward(self, x):
             B, T, C = x.size()  # x = (Batch,Time,Channel)
             sr, k, v = self.jit_func(x)
-            rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
             return self.output(rwkv)

     if 'a' in os.environ["RWKV_MY_TESTING"]:
@@ -213,7 +212,7 @@ class RWKV_TimeMix(MyModule):
         def forward(self, x):
             B, T, C = x.size()  # x = (Batch,Time,Channel)
             sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
-            rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
             rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
             return rwkv

@@ -237,10 +236,9 @@ class RWKV_ChannelMix(MyModule):
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

-        hidden_sz = 4 * args.n_embd
-        self.key = nn.Linear(args.n_embd, hidden_sz, bias=False)
+        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
-        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
+        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

     @MyFunction
     def forward(self, x):
@@ -252,6 +250,36 @@ class RWKV_ChannelMix(MyModule):
         kv = self.value(k)
         return torch.sigmoid(self.receptance(xr)) * kv

+class MishGLU(MyModule):
+    def __init__(self, args, layer_id):
+        super().__init__()
+        self.args = args
+        self.layer_id = layer_id
+        self.my_testing = self.args.my_testing
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        with torch.no_grad():
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
+
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
+
+            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+
+    @MyFunction
+    def forward(self, x):
+        xx = self.time_shift(x)
+        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+        a = self.aa(xa)
+        b = self.bb(xb)
+        return self.value(a * F.mish(b))
+
 ########################################################################################################
 # The RWKV Model with our blocks
 ########################################################################################################
@@ -277,7 +305,10 @@ class Block(nn.Module):
             else:
                 self.att = RWKV_TimeMix(args, layer_id)

-        self.ffn = RWKV_ChannelMix(args, layer_id)
+        if 'g' in os.environ["RWKV_MY_TESTING"]:
+            self.ffn = MishGLU(args, layer_id)
+        else:
+            self.ffn = RWKV_ChannelMix(args, layer_id)

         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
             self.tiny_ln = nn.LayerNorm(args.n_embd)
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 9b63dea..ba63e03 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -67,6 +67,8 @@ if __name__ == "__main__":
     parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
+    parser.add_argument("--dim_att", default=0, type=int)
+    parser.add_argument("--dim_ffn", default=0, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
     parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
     parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
@@ -139,6 +141,10 @@ if __name__ == "__main__":
     args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
     os.environ["RWKV_T_MAX"] = str(args.ctx_len)
     os.environ["RWKV_MY_TESTING"] = args.my_testing
+    if args.dim_att <= 0:
+        args.dim_att = args.n_embd
+    if args.dim_ffn <= 0:
+        args.dim_ffn = args.n_embd * 4
     if args.data_type == "wds_img":
         args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
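Note on the change: with the defaults added in train.py, --dim_att falls back to n_embd and --dim_ffn to 4 * n_embd, so existing configs build exactly the same model as before; the MishGLU feed-forward path is only taken when 'g' is present in --my_testing (e.g. running train.py with --n_embd 512 --my_testing g would build MishGLU layers with dim_ffn = 2048).

For reference, below is a minimal standalone sketch of the gating math in the new MishGLU block. It is an assumption-laden simplification, not the repo's code: plain nn.Module stands in for the MyModule/MyFunction JIT wrappers, and a fixed scalar mix stands in for the learned, layer-dependent time_mix_k / time_mix_r initialization.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MishGLUSketch(nn.Module):
        # Hypothetical simplification of MishGLU above: plain nn.Module
        # replaces MyModule/MyFunction, and a fixed scalar `mix` replaces
        # the per-layer time_mix_k / time_mix_r parameters.
        def __init__(self, n_embd, dim_ffn, mix=0.5):
            super().__init__()
            # Shift features one step back along the time axis:
            # pad one zero row at the start of T, drop the last row.
            self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
            self.mix = mix
            self.aa = nn.Linear(n_embd, dim_ffn, bias=False)     # GLU "content" branch
            self.bb = nn.Linear(n_embd, dim_ffn, bias=False)     # GLU "gate" branch
            self.value = nn.Linear(dim_ffn, n_embd, bias=False)  # project back to n_embd

        def forward(self, x):  # x: (Batch, Time, Channel)
            xx = self.time_shift(x)
            xa = x * self.mix + xx * (1 - self.mix)  # blend current and previous token
            xb = x * self.mix + xx * (1 - self.mix)
            # mish(b) gates the a branch, in place of ChannelMix's
            # sigmoid(receptance(xr)) * value(k) path.
            return self.value(self.aa(xa) * F.mish(self.bb(xb)))

    x = torch.randn(2, 8, 512)
    print(MishGLUSketch(512, 2048)(x).shape)  # torch.Size([2, 8, 512])

The output projection keeps the block shape-compatible with RWKV_ChannelMix, which is why Block can swap one for the other based solely on the 'g' flag.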