@@ -102,17 +102,18 @@ def RUN_CUDA(B, T, C, w, u, k, v):
 
 
 class RWKV_TimeMix(MyModule):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
+        self.args = args
         self.layer_id = layer_id
-        self.ctx_len = config.ctx_len
-        self.n_embd = config.n_embd
+        self.ctx_len = args.ctx_len
+        self.n_embd = args.n_embd
 
-        attn_sz = config.n_embd
+        attn_sz = args.n_embd
 
         with torch.no_grad(): # fancy init
-            ratio_0_to_1 = layer_id / (config.n_layer - 1) # 0 to 1
-            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
+            ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
 
             # fancy time_decay
             decay_speed = torch.ones(attn_sz)
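
The two ratios above depend only on the layer index, so the depth-dependent init schedule can be inspected in isolation. A minimal sketch, assuming a hypothetical n_layer = 6 (not a value taken from this patch):

# How the "fancy init" ratios vary with depth (hypothetical n_layer = 6).
n_layer = 6
for layer_id in range(n_layer):
    ratio_0_to_1 = layer_id / (n_layer - 1)          # 0.0 at the first layer, 1.0 at the last
    ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1.0 at the first layer, ~0.17 at the last
    print(layer_id, round(ratio_0_to_1, 2), round(ratio_1_to_almost0, 2))
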
@@ -126,20 +127,20 @@ class RWKV_TimeMix(MyModule):
             self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
 
             # fancy time_mix
-            x = torch.ones(1, 1, config.n_embd)
-            for i in range(config.n_embd):
-                x[0, 0, i] = i / config.n_embd
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
             self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
-        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
-        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
-        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
+        self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
+        self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
+        self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
 
-        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
+        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
 
     @MyFunction
     def jit_func(self, x):
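
For readers unfamiliar with the token-shift trick: nn.ZeroPad2d((0, 0, 1, -1)) pads one step at the front of the time axis and crops one at the back, so each position sees the previous token's features. A runnable sketch with toy shapes and a constant mix weight (both are assumptions, not values from this patch):

import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # (left, right, top, bottom) over the last two dims (T, C)

x = torch.arange(8, dtype=torch.float32).reshape(1, 4, 2)  # (B, T, C)
xx = time_shift(x)                             # xx[:, t, :] == x[:, t - 1, :], zeros at t == 0

time_mix_k = torch.full((1, 1, 2), 0.5)        # stand-in for the learned time_mix_k parameter
xk = x * time_mix_k + xx * (1 - time_mix_k)    # per-channel blend of current and previous token
print(xx[0])                                   # first row is all zeros
print(xk[0])

time_mix_k, time_mix_v and time_mix_r above are the learned, per-channel versions of that blend weight.
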
@@ -169,26 +170,27 @@ class RWKV_TimeMix(MyModule):
 
 
 class RWKV_ChannelMix(MyModule):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
+        self.args = args
         self.layer_id = layer_id
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
         with torch.no_grad(): # fancy init of time_mix
-            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
 
-            x = torch.ones(1, 1, config.n_embd)
-            for i in range(config.n_embd):
-                x[0, 0, i] = i / config.n_embd
+            x = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                x[0, 0, i] = i / args.n_embd
 
             self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
             self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
 
-        hidden_sz = 4 * config.n_embd
-        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
-        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
-        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
+        hidden_sz = 4 * args.n_embd
+        self.key = nn.Linear(args.n_embd, hidden_sz, bias=False)
+        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
 
     @MyFunction
     def forward(self, x):
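
The body of forward is outside this hunk. For orientation only, here is a sketch of the standard RWKV-4 channel-mix computation these projections are typically used in (squared-ReLU key, sigmoid receptance as a gate); treat it as a reference sketch, not as code from this patch:

import torch

def channel_mix_sketch(x, time_shift, time_mix_k, time_mix_r, key, value, receptance):
    # Sketch of the usual RWKV channel-mix forward; arguments are the module's own layers/parameters.
    xx = time_shift(x)                               # previous token's features
    xk = x * time_mix_k + xx * (1 - time_mix_k)
    xr = x * time_mix_r + xx * (1 - time_mix_r)
    k = torch.square(torch.relu(key(xk)))            # squared-ReLU "key" activations
    return torch.sigmoid(receptance(xr)) * value(k)  # sigmoid receptance gates the value
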
@@ -210,32 +212,54 @@ class RWKV_ChannelMix(MyModule):
 
 
 class Block(nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, args, layer_id):
         super().__init__()
-        self.config = config
+        self.args = args
         self.layer_id = layer_id
 
-        self.ln1 = nn.LayerNorm(config.n_embd)
-        self.ln2 = nn.LayerNorm(config.n_embd)
+        self.ln1 = nn.LayerNorm(args.n_embd)
+        self.ln2 = nn.LayerNorm(args.n_embd)
 
         if self.layer_id == 0:
-            self.ln0 = nn.LayerNorm(config.n_embd)
+            self.ln0 = nn.LayerNorm(args.n_embd)
+            if args.my_pos_emb > 0:
+                self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
+                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
 
-        if self.layer_id == 0 and self.config.pre_ffn > 0:
-            self.ffnPre = RWKV_ChannelMix(config, 0)
+        if self.layer_id == 0 and self.args.pre_ffn > 0:
+            self.ffnPre = RWKV_ChannelMix(args, 0)
         else:
-            self.att = RWKV_TimeMix(config, layer_id)
+            self.att = RWKV_TimeMix(args, layer_id)
 
-        self.ffn = RWKV_ChannelMix(config, layer_id)
+        self.ffn = RWKV_ChannelMix(args, layer_id)
 
-    def forward(self, x):
+        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
+            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+
+    def forward(self, x, x_emb=None):
+        args = self.args
+        B, T, C = x.size()
         if self.layer_id == 0:
             x = self.ln0(x)
-        if self.layer_id == 0 and self.config.pre_ffn > 0:
-            x = x + self.ffnPre(self.ln1(x)) # better in some cases
+            if args.my_pos_emb > 0:
+                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+                x = x + pos_emb
+
+        if self.layer_id == 0 and args.pre_ffn > 0:
+            x = x + self.ffnPre(self.ln1(x))
         else:
             x = x + self.att(self.ln1(x))
         x = x + self.ffn(self.ln2(x))
+
+        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
+            q = self.head_q(x)[:, :T, :]
+            k = self.head_k(x)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
+            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
+            x = x + c @ self.head_v(x_emb)
+
         return x
 
 
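
The tiny-attention branch added to Block.forward is a single extra attention head whose scores are simply zeroed outside the causal mask (there is no softmax) and whose values come from the embedding stream x_emb. The same computation restated as a free function, with the module attributes passed in as arguments for readability:

def tiny_att_sketch(x, x_emb, head_q, head_k, head_v, head_mask, downscale):
    # x, x_emb: (B, T, C) tensors; head_q/head_k/head_v: the Linear layers from Block.__init__.
    B, T, C = x.size()
    q = head_q(x)[:, :T, :]                             # (B, T, tiny_att_dim)
    k = head_k(x)[:, :T, :]                             # (B, T, tiny_att_dim)
    c = (q @ k.transpose(-2, -1)) * (1.0 / downscale)   # (B, T, T) raw scores
    c = c.masked_fill(head_mask[:T, :T] == 0, 0)        # future positions contribute nothing
    return x + c @ head_v(x_emb)                        # values are taken from the embedding stream
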
@@ -262,9 +286,6 @@ class RWKV(pl.LightningModule):
         self.args = args
 
         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
-        if args.my_pos_emb > 0:
-            self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
-            self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
 
         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
 
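
The factored position embedding removed here (and re-created inside Block above) relies on broadcasting: a (1, P, C) row table plus a (P, 1, C) column table yields a full (P, P, C) grid with one vector per cell. A toy illustration with made-up sizes:

import torch

P, C = 3, 2                          # hypothetical my_pos_emb and n_embd
pos_emb_x = torch.zeros((1, P, C))   # varies along one grid axis
pos_emb_y = torch.zeros((P, 1, C))   # varies along the other grid axis
grid = pos_emb_x + pos_emb_y         # broadcasts to (P, P, C): one embedding per grid cell
print(grid.shape)                    # torch.Size([3, 3, 2])

In the forward path the grid is then flattened with reshape(T+1, -1) and the last row dropped so it lines up with the T positions of x.
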
@@ -330,8 +351,8 @@ class RWKV(pl.LightningModule):
     def deepspeed_offload(self) -> bool:
         strategy = self.trainer.strategy
         if isinstance(strategy, DeepSpeedStrategy):
-            config = strategy.config["zero_optimization"]
-            return config.get("offload_optimizer") or config.get("offload_param")
+            cfg = strategy.config["zero_optimization"]
+            return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False
 
     def forward(self, idx):
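
The rename to cfg only avoids reusing the old config name; the keys being read are standard DeepSpeed ZeRO options. An illustrative shape of the zero_optimization section this property inspects (values are examples, not taken from this repo):

# Example zero_optimization section of a DeepSpeed config (illustrative values only).
zero_cfg = {
    "stage": 3,
    "offload_optimizer": {"device": "cpu", "pin_memory": True},
    "offload_param": {"device": "cpu", "pin_memory": True},
}
offload = bool(zero_cfg.get("offload_optimizer") or zero_cfg.get("offload_param"))
print(offload)  # True when either offload section is present
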
@@ -340,15 +361,20 @@ class RWKV(pl.LightningModule):
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
 
         x = self.emb(idx)
-        if args.my_pos_emb > 0:
-            pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
-            x = x + pos_emb
+        x_emb = x
 
-        for block in self.blocks:
-            if args.grad_cp == 1:
-                x = deepspeed.checkpointing.checkpoint(block, x)
-            else:
-                x = block(x)
+        if args.tiny_att_dim > 0:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+                else:
+                    x = block(x, x_emb)
+        else:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x)
+                else:
+                    x = block(x)
 
         x = self.ln_out(x)
 
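
The grad_cp == 1 branches trade compute for memory by recomputing each block during backward; the only change in this hunk is that x_emb is threaded through when tiny attention is enabled. A minimal, self-contained sketch of the same idea, using torch.utils.checkpoint as a stand-in for deepspeed.checkpointing.checkpoint:

import torch
from torch.utils.checkpoint import checkpoint  # stand-in for deepspeed.checkpointing.checkpoint

block = torch.nn.Linear(8, 8)                  # toy stand-in for a Block
x = torch.randn(2, 4, 8, requires_grad=True)

y = checkpoint(block, x, use_reentrant=False)  # same output as block(x); activations recomputed in backward
y.sum().backward()
print(x.grad.shape)                            # torch.Size([2, 4, 8])

Extra inputs such as x_emb are forwarded the same way, as additional positional arguments to the checkpoint call.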