From e9b24370d96b0fd7316c4a99e9ba68dba3039327 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Fri, 25 Nov 2022 12:02:43 +0000
Subject: [PATCH] RWKV-4a (tinyAtt)

---
 RWKV-v4neo/src/model.py   | 120 +++++++++++++++++++++++---------------
 RWKV-v4neo/src/trainer.py |  11 ++--
 RWKV-v4neo/train.py       |   6 +-
 3 files changed, 85 insertions(+), 52 deletions(-)

diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 1e0ae1b..0b961d6 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -102,17 +102,18 @@ def RUN_CUDA(B, T, C, w, u, k, v):
 
 
 class RWKV_TimeMix(MyModule):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
+        self.args = args
         self.layer_id = layer_id
-        self.ctx_len = config.ctx_len
-        self.n_embd = config.n_embd
+        self.ctx_len = args.ctx_len
+        self.n_embd = args.n_embd
 
-        attn_sz = config.n_embd
+        attn_sz = args.n_embd
 
         with torch.no_grad():  # fancy init
-            ratio_0_to_1 = layer_id / (config.n_layer - 1)  # 0 to 1
-            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
+            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
 
             # fancy time_decay
             decay_speed = torch.ones(attn_sz)
@@ -126,20 +127,20 @@ class RWKV_TimeMix(MyModule):
             self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
 
             # fancy time_mix
-            x = torch.ones(1, 1, config.n_embd)
-            for i in range(config.n_embd):
-                x[0, 0, i] = i / config.n_embd
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
             self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
-        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
-        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
-        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
+        self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
+        self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
+        self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
 
-        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
+        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
 
     @MyFunction
     def jit_func(self, x):
@@ -169,26 +170,27 @@ class RWKV_TimeMix(MyModule):
 
 
 class RWKV_ChannelMix(MyModule):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
+        self.args = args
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
         with torch.no_grad():  # fancy init of time_mix
-            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
 
-            x = torch.ones(1, 1, config.n_embd)
-            for i in range(config.n_embd):
-                x[0, 0, i] = i / config.n_embd
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
 
-        hidden_sz = 4 * config.n_embd
-        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
-        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
-        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
+        hidden_sz = 4 * args.n_embd
+        self.key = nn.Linear(args.n_embd, hidden_sz, bias=False)
+        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
 
     @MyFunction
     def forward(self, x):
@@ -210,32 +212,54 @@ class RWKV_ChannelMix(MyModule):
 
 
 class Block(nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
-        self.config = config
+        self.args = args
         self.layer_id = layer_id
 
-        self.ln1 = nn.LayerNorm(config.n_embd)
-        self.ln2 = nn.LayerNorm(config.n_embd)
+        self.ln1 = nn.LayerNorm(args.n_embd)
+        self.ln2 = nn.LayerNorm(args.n_embd)
 
         if self.layer_id == 0:
-            self.ln0 = nn.LayerNorm(config.n_embd)
+            self.ln0 = nn.LayerNorm(args.n_embd)
+            if args.my_pos_emb > 0:
+                self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
+                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
 
-        if self.layer_id == 0 and self.config.pre_ffn > 0:
-            self.ffnPre = RWKV_ChannelMix(config, 0)
+        if self.layer_id == 0 and self.args.pre_ffn > 0:
+            self.ffnPre = RWKV_ChannelMix(args, 0)
         else:
-            self.att = RWKV_TimeMix(config, layer_id)
+            self.att = RWKV_TimeMix(args, layer_id)
 
-        self.ffn = RWKV_ChannelMix(config, layer_id)
+        self.ffn = RWKV_ChannelMix(args, layer_id)
+
+        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
+            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
-    def forward(self, x):
+    def forward(self, x, x_emb=None):
+        args = self.args
+        B, T, C = x.size()
         if self.layer_id == 0:
             x = self.ln0(x)
-        if self.layer_id == 0 and self.config.pre_ffn > 0:
-            x = x + self.ffnPre(self.ln1(x))  # better in some cases
+            if args.my_pos_emb > 0:
+                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+                x = x + pos_emb
+
+        if self.layer_id == 0 and args.pre_ffn > 0:
+            x = x + self.ffnPre(self.ln1(x))
         else:
             x = x + self.att(self.ln1(x))
         x = x + self.ffn(self.ln2(x))
+
+        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
+            q = self.head_q(x)[:, :T, :]
+            k = self.head_k(x)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
+            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
+            x = x + c @ self.head_v(x_emb)
         return x
@@ -262,9 +286,6 @@ class RWKV(pl.LightningModule):
         self.args = args
 
         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
-        if args.my_pos_emb > 0:
-            self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
-            self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
 
         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
 
@@ -330,8 +351,8 @@ class RWKV(pl.LightningModule):
     def deepspeed_offload(self) -> bool:
         strategy = self.trainer.strategy
         if isinstance(strategy, DeepSpeedStrategy):
-            config = strategy.config["zero_optimization"]
-            return config.get("offload_optimizer") or config.get("offload_param")
+            cfg = strategy.config["zero_optimization"]
+            return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False
 
     def forward(self, idx):
@@ -340,15 +361,20 @@ class RWKV(pl.LightningModule):
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
 
         x = self.emb(idx)
-        if args.my_pos_emb > 0:
-            pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
-            x = x + pos_emb
+        x_emb = x
 
-        for block in self.blocks:
-            if args.grad_cp == 1:
-                x = deepspeed.checkpointing.checkpoint(block, x)
-            else:
-                x = block(x)
+        if args.tiny_att_dim > 0:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+                else:
+                    x = block(x, x_emb)
+        else:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x)
+                else:
+                    x = block(x)
 
         x = self.ln_out(x)
 
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index c29979b..6f1b230 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -119,10 +119,13 @@ class train_callback(pl.Callback):
                             to_save_dict[k] = raw_dict[k]
                 else:
                     to_save_dict = pl_module.state_dict()
-                torch.save(
-                    to_save_dict,
-                    f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
-                )
+                try:
+                    torch.save(
+                        to_save_dict,
+                        f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
+                    )
+                except:
+                    pass
             trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
             trainer.my_log.flush()
 
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 0834559..580546e 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -67,7 +67,10 @@ if __name__ == "__main__":
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
-    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick. try 256 if you want to test it
+    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
+    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
+    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer
+    parser.add_argument("--tiny_att_downscale", default=0, type=float)
     parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
 
@@ -232,6 +235,7 @@ if __name__ == "__main__":
         os.environ["RWKV_JIT_ON"] = "0"
 
     torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
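
Note: the core of this patch is the "tiny attention" branch added to Block. The sketch below restates that branch as a standalone PyTorch module so the shapes are easier to follow. It is illustrative only; the class name TinyAtt and its constructor signature do not exist in the repo. It mirrors the patched code: low-rank q/k projections of the block output, a full-width value projection of the raw token embeddings (x_emb), a causal mask with no softmax, a fixed 1/tiny_att_downscale scaling, and a residual add.

    # Illustrative sketch only -- "TinyAtt" is not a class in the repo.
    import torch
    import torch.nn as nn

    class TinyAtt(nn.Module):
        def __init__(self, n_embd, tiny_att_dim, ctx_len, tiny_att_downscale):
            super().__init__()
            self.downscale = tiny_att_downscale
            self.head_q = nn.Linear(n_embd, tiny_att_dim, bias=False)  # low-rank query
            self.head_k = nn.Linear(n_embd, tiny_att_dim, bias=False)  # low-rank key
            self.head_v = nn.Linear(n_embd, n_embd, bias=False)        # full-width value
            self.register_buffer("head_mask", torch.tril(torch.ones(ctx_len, ctx_len)))

        def forward(self, x, x_emb):
            # x: output of the RWKV block, x_emb: raw token embeddings, both [B, T, C]
            T = x.size(1)
            q = self.head_q(x)                                 # [B, T, tiny_att_dim]
            k = self.head_k(x)                                 # [B, T, tiny_att_dim]
            c = (q @ k.transpose(-2, -1)) / self.downscale     # [B, T, T] scores, no softmax
            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)  # causal: zero out future positions
            return x + c @ self.head_v(x_emb)                  # residual add of attended embeddings

Unlike standard attention, the scores are not normalized with softmax; they are simply divided by tiny_att_downscale and zeroed above the diagonal, and the values are taken from the original token embeddings rather than from the block output.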
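
Note on usage: the new flags leave tiny attention disabled by default (--tiny_att_dim 0, --tiny_att_layer -999 matches no layer), so existing training commands behave as before. A hypothetical invocation that turns it on is shown below; the values 256 / 3 / 256 are placeholders for illustration, not settings recommended by the patch, and the usual data, ctx_len and precision flags still apply. Since --tiny_att_downscale defaults to 0 and is used as a divisor in Block.forward, it must be set to a positive value whenever --tiny_att_dim > 0.

    # Hypothetical example -- flag values are placeholders, combine with your usual data/precision flags.
    python train.py --n_layer 6 --n_embd 512 \
        --tiny_att_dim 256 --tiny_att_layer 3 --tiny_att_downscale 256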