RWKV-4a (tinyAtt)

BlinkDL · 3 years ago · branch main
commit e9b24370d9 (parent 2567c8c904)

@@ -102,17 +102,18 @@ def RUN_CUDA(B, T, C, w, u, k, v):

 class RWKV_TimeMix(MyModule):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
+        self.args = args
         self.layer_id = layer_id
-        self.ctx_len = config.ctx_len
-        self.n_embd = config.n_embd
-        attn_sz = config.n_embd
+        self.ctx_len = args.ctx_len
+        self.n_embd = args.n_embd
+        attn_sz = args.n_embd

         with torch.no_grad(): # fancy init
-            ratio_0_to_1 = layer_id / (config.n_layer - 1) # 0 to 1
-            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
+            ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0

             # fancy time_decay
             decay_speed = torch.ones(attn_sz)
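
Editor's note on the init schedule above: the two ratios only shape initialization, not training. With the default n_layer = 6, layer 0 gets ratio_1_to_almost0 = 1.0 (so the time_mix parameters in the next hunk ramp linearly across channels), while layer 5 gets 1 - 5/6 ≈ 0.17, and torch.pow(x, 0.17) hugs 1.0 for most channels, so deeper layers start out mixing mostly the current token.
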
@@ -126,20 +127,20 @@ class RWKV_TimeMix(MyModule):
             self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

             # fancy time_mix
-            x = torch.ones(1, 1, config.n_embd)
-            for i in range(config.n_embd):
-                x[0, 0, i] = i / config.n_embd
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
             self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

-        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
-        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
-        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
-        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
+        self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
+        self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
+        self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
+        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)

     @MyFunction
     def jit_func(self, x):
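
The jit_func body falls outside this hunk. For orientation, a sketch of the token-shift interpolation it performs with these parameters, paraphrased from the surrounding file (not part of this change; written as a free function so it stands alone):

import torch
import torch.nn.functional as F

def time_mix_lerp(x, time_mix_k, time_mix_v, time_mix_r):
    # blend each token with its predecessor, per channel
    xx = F.pad(x, (0, 0, 1, -1))  # same effect as the time_shift ZeroPad2d above
    xk = x * time_mix_k + xx * (1 - time_mix_k)  # input to the key projection
    xv = x * time_mix_v + xx * (1 - time_mix_v)  # ... value projection
    xr = x * time_mix_r + xx * (1 - time_mix_r)  # ... receptance (gate) projection
    return xk, xv, xr
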
@@ -169,26 +170,27 @@ class RWKV_TimeMix(MyModule):

 class RWKV_ChannelMix(MyModule):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
+        self.args = args
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

         with torch.no_grad(): # fancy init of time_mix
-            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
-            x = torch.ones(1, 1, config.n_embd)
-            for i in range(config.n_embd):
-                x[0, 0, i] = i / config.n_embd
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

-        hidden_sz = 4 * config.n_embd
-        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
-        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
-        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
+        hidden_sz = 4 * args.n_embd
+        self.key = nn.Linear(args.n_embd, hidden_sz, bias=False)
+        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)

     @MyFunction
     def forward(self, x):
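
As above, the forward body sits outside the hunk; a paraphrase from the repo for orientation (unchanged by this commit), again as a standalone function:

import torch
import torch.nn.functional as F

def channel_mix(x, time_mix_k, time_mix_r, key, receptance, value):
    # token-shift lerp, then a squared-ReLU FFN gated by sigmoid(receptance)
    xx = F.pad(x, (0, 0, 1, -1))  # previous token's features
    xk = x * time_mix_k + xx * (1 - time_mix_k)
    xr = x * time_mix_r + xx * (1 - time_mix_r)
    k = torch.square(torch.relu(key(xk)))  # squared ReLU on the 4x-wide hidden
    return torch.sigmoid(receptance(xr)) * value(k)
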
@@ -210,32 +212,54 @@ class RWKV_ChannelMix(MyModule):

 class Block(nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
-        self.config = config
+        self.args = args
         self.layer_id = layer_id

-        self.ln1 = nn.LayerNorm(config.n_embd)
-        self.ln2 = nn.LayerNorm(config.n_embd)
+        self.ln1 = nn.LayerNorm(args.n_embd)
+        self.ln2 = nn.LayerNorm(args.n_embd)

         if self.layer_id == 0:
-            self.ln0 = nn.LayerNorm(config.n_embd)
+            self.ln0 = nn.LayerNorm(args.n_embd)
+            if args.my_pos_emb > 0:
+                self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
+                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))

-        if self.layer_id == 0 and self.config.pre_ffn > 0:
-            self.ffnPre = RWKV_ChannelMix(config, 0)
+        if self.layer_id == 0 and self.args.pre_ffn > 0:
+            self.ffnPre = RWKV_ChannelMix(args, 0)
         else:
-            self.att = RWKV_TimeMix(config, layer_id)
+            self.att = RWKV_TimeMix(args, layer_id)

-        self.ffn = RWKV_ChannelMix(config, layer_id)
+        self.ffn = RWKV_ChannelMix(args, layer_id)
+
+        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
+            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

-    def forward(self, x):
+    def forward(self, x, x_emb=None):
+        args = self.args
+        B, T, C = x.size()
+
         if self.layer_id == 0:
             x = self.ln0(x)
+            if args.my_pos_emb > 0:
+                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+                x = x + pos_emb

-        if self.layer_id == 0 and self.config.pre_ffn > 0:
-            x = x + self.ffnPre(self.ln1(x)) # better in some cases
+        if self.layer_id == 0 and args.pre_ffn > 0:
+            x = x + self.ffnPre(self.ln1(x))
         else:
             x = x + self.att(self.ln1(x))
         x = x + self.ffn(self.ln2(x))

+        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
+            q = self.head_q(x)[:, :T, :]
+            k = self.head_k(x)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
+            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
+            x = x + c @ self.head_v(x_emb)
+
         return x
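
This tiny attention head is the commit's namesake. A self-contained sketch of the mechanism, with shapes annotated (the TinyAttention wrapper class is hypothetical; the projection names and math follow the diff):

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # standalone restatement of the head added to Block above
    def __init__(self, n_embd, tiny_att_dim, ctx_len, downscale):
        super().__init__()
        self.downscale = downscale
        self.head_q = nn.Linear(n_embd, tiny_att_dim, bias=False)  # low-rank queries
        self.head_k = nn.Linear(n_embd, tiny_att_dim, bias=False)  # low-rank keys
        self.head_v = nn.Linear(n_embd, n_embd, bias=False)        # full-width values
        self.register_buffer("mask", torch.tril(torch.ones(ctx_len, ctx_len)))

    def forward(self, x, x_emb):
        # x: hidden state at this layer; x_emb: raw token embeddings; both (B, T, C)
        B, T, C = x.size()
        q = self.head_q(x)                              # (B, T, tiny_att_dim)
        k = self.head_k(x)                              # (B, T, tiny_att_dim)
        c = (q @ k.transpose(-2, -1)) / self.downscale  # (B, T, T) raw scores
        c = c.masked_fill(self.mask[:T, :T] == 0, 0)    # causal: zero, not -inf
        return c @ self.head_v(x_emb)                   # residual added by the caller

Two deliberate departures from standard attention are visible in the diff: there is no softmax (the masked scores weight the values linearly, so future positions are zeroed rather than filled with -inf), and the values come from the raw embedding x_emb instead of the current hidden state, giving the chosen layer a cheap, content-addressed path back to the input tokens.
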
@@ -262,9 +286,6 @@ class RWKV(pl.LightningModule):
         self.args = args

         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
-        if args.my_pos_emb > 0:
-            self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
-            self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))

         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
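
The factorized positional embedding removed here now lives in Block 0 (see the previous hunk). It spends 2·M·C parameters to cover M² positions by broadcasting a row vector against a column vector; note the reshape(T+1, -1)[:-1,:] indexing only lines up when T + 1 == M², i.e. ctx_len == my_pos_emb² - 1. A shape check with illustrative values:

import torch

M, C = 32, 512                    # M = args.my_pos_emb, C = n_embd (illustrative)
T = M * M - 1                     # the ctx_len this trick assumes (1023 here)
pos_emb_x = torch.zeros(1, M, C)  # varies along one grid axis
pos_emb_y = torch.zeros(M, 1, C)  # varies along the other
grid = pos_emb_x + pos_emb_y      # broadcasts to (M, M, C): M*M positions
pos_emb = grid.reshape(T + 1, -1)[:-1, :]  # (M*M, C) minus the last row -> (T, C)
assert pos_emb.shape == (T, C)    # ready to broadcast-add onto x of shape (B, T, C)
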
@@ -330,8 +351,8 @@ class RWKV(pl.LightningModule):
     def deepspeed_offload(self) -> bool:
         strategy = self.trainer.strategy
         if isinstance(strategy, DeepSpeedStrategy):
-            config = strategy.config["zero_optimization"]
-            return config.get("offload_optimizer") or config.get("offload_param")
+            cfg = strategy.config["zero_optimization"]
+            return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False
@@ -340,15 +361,20 @@ class RWKV(pl.LightningModule):
     def forward(self, idx):
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

         x = self.emb(idx)
-        if args.my_pos_emb > 0:
-            pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
-            x = x + pos_emb
+        x_emb = x

-        for block in self.blocks:
-            if args.grad_cp == 1:
-                x = deepspeed.checkpointing.checkpoint(block, x)
-            else:
-                x = block(x)
+        if args.tiny_att_dim > 0:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+                else:
+                    x = block(x, x_emb)
+        else:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x)
+                else:
+                    x = block(x)

         x = self.ln_out(x)
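
The loop is written twice because deepspeed.checkpointing.checkpoint forwards its positional arguments straight through to the block, so x_emb is only supplied when tiny attention is configured. A more compact equivalent, as a sketch (untested against DeepSpeed edge cases, not the committed code):

for block in self.blocks:
    inputs = (x, x_emb) if args.tiny_att_dim > 0 else (x,)
    if args.grad_cp == 1:
        x = deepspeed.checkpointing.checkpoint(block, *inputs)
    else:
        x = block(*inputs)
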

@@ -119,10 +119,13 @@ class train_callback(pl.Callback):
                         to_save_dict[k] = raw_dict[k]
             else:
                 to_save_dict = pl_module.state_dict()
-            torch.save(
-                to_save_dict,
-                f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
-            )
+            try:
+                torch.save(
+                    to_save_dict,
+                    f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
+                )
+            except:
+                pass
         trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
         trainer.my_log.flush()
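
The new try/except keeps a long run alive when a checkpoint write fails (full disk, flaky network filesystem), at the cost of losing that checkpoint silently. A variant that at least surfaces the failure, with ckpt_path standing in for the f-string above (a sketch, not the committed code):

try:
    torch.save(to_save_dict, ckpt_path)
except Exception as e:  # deliberately broad, but no longer silent
    print(f"[warning] checkpoint save failed: {e}")
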

@@ -67,7 +67,10 @@ if __name__ == "__main__":
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
-    parser.add_argument("--head_qk", default=0, type=int) # my headQK trick. try 256 if you want to test it
+    parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
+    parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
+    parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
+    parser.add_argument("--tiny_att_downscale", default=0, type=float)
     parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
@@ -232,6 +235,7 @@ if __name__ == "__main__":
         os.environ["RWKV_JIT_ON"] = "0"

     torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.enabled = True
     if args.precision == "fp32":
         torch.backends.cudnn.allow_tf32 = False
         torch.backends.cuda.matmul.allow_tf32 = False
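
torch.backends.cudnn.enabled already defaults to True in PyTorch, so the added line documents intent rather than changing behavior.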
