@@ -33,11 +33,11 @@ class RWKV_TimeMix(nn.Module):
self.key = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.output = nn.Linear(config.n_embd, config.n_embd)
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
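
The hunk above stops right after the pad-and-tile step, but this is the beginning of a standard trick for turning a length-T kernel into a causal (T, T) Toeplitz weight matrix. A minimal 1-D sketch of how the continuation usually looks; the kernel values and the reshape/slice that presumably follow are illustrative assumptions, not lines from this file (in the module itself time_w appears to be per-head, given the ww[k][h] indexing later in this diff):

import torch
import torch.nn.functional as F

T = 4
time_w = torch.tensor([0.1, 0.2, 0.4, 1.0])   # stand-in for a learned per-offset kernel

w = F.pad(time_w, (0, T))           # [k0, k1, k2, k3, 0, 0, 0, 0]
w = torch.tile(w, [T])              # repeat the padded kernel T times
w = w[:-T].reshape(T, 2 * T - 1)    # each row is the kernel shifted right by one slot
w = w[:, T - 1:]                    # keep the causal half: lower-triangular Toeplitz (T, T)
print(w)
# row t holds the weights applied to positions 0..t, with time_w[-1] on the diagonal
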
@@ -51,7 +51,7 @@ class RWKV_TimeMix(nn.Module):
v = self.value(x)
r = self.receptance(x)
- k = torch.clamp(k, max=30) # clamp crazy values
+ k = torch.clamp(k, max=30) # clamp extreme values
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
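
For context, the clamp/exp/cumsum trio above builds positive per-position weights and their causal running totals. A minimal, self-contained sketch of how such weights can produce a causally normalized weighted average of the values; this illustrates the idea only, it is not the repository's exact forward pass, and the learned time_w curve seen earlier is omitted (toy_mix and its shapes are assumptions):

import torch

def toy_mix(k, v, r):
    # k, v, r: (B, T, C) projections of the input
    k = torch.clamp(k, max=30)             # avoid exp() overflow for extreme keys
    k = torch.exp(k)                       # positive per-position weights
    sum_k = torch.cumsum(k, dim=1)         # running total over time (causal)
    wkv = torch.cumsum(k * v, dim=1)       # running weighted sum of values
    return torch.sigmoid(r) * wkv / sum_k  # gated, normalized causal average

B, T, C = 2, 8, 16
out = toy_mix(torch.randn(B, T, C), torch.randn(B, T, C), torch.randn(B, T, C))
print(out.shape)  # torch.Size([2, 8, 16])
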
@@ -154,7 +154,7 @@ class MHA_rotary(nn.Module):
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
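
Only the first rotary_ndims channels of q and k are rotated here; query_pass and key_pass are carried through unchanged and are presumably concatenated back after the rotation. The repository's RotaryEmbedding and apply_rotary_pos_emb helpers are not shown in this diff; below is a sketch of the standard GPT-NeoX-style formulation they most likely follow (function names and details are assumptions):

import torch

def rotary_cos_sin(seq_len, dim, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)   # (T, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)        # (T, dim)
    return emb.cos(), emb.sin()

def rotate_half(x):
    # split the feature dim in two and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin broadcast over the batch and head dims of (B, nh, T, rotary_ndims)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

cos, sin = rotary_cos_sin(seq_len=8, dim=16)
q = k = torch.randn(2, 4, 8, 16)
q_rot, k_rot = apply_rotary(q, k, cos, sin)
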
@@ -163,7 +163,7 @@ class MHA_rotary(nn.Module):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
att = F.softmax(att, dim = -1) # softmax
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
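
The mask/softmax/matmul sequence above is the standard causal attention kernel with 1/sqrt(head_size) scaling. As an aside (not something this file does), on PyTorch 2.0 and later the same computation can be written with the fused kernel:

import torch
import torch.nn.functional as F

B, nh, T, hs = 2, 4, 8, 16
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# is_causal=True applies the same lower-triangular mask; default scaling is 1/sqrt(hs)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
x = y.transpose(1, 2).contiguous().view(B, T, nh * hs)   # back to (B, T, C)
print(x.shape)   # torch.Size([2, 8, 64])
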
@@ -196,7 +196,7 @@ class GeGLU(torch.nn.Module):
########################################################################################################
# MHA_pro: with more tricks
########################################################################################################
class MHA_pro(nn.Module):
def __init__(self, config, layer_id):
@@ -211,7 +211,7 @@ class MHA_pro(nn.Module):
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,0))
self.query = nn.Linear(config.n_embd, config.n_embd)
@@ -239,7 +239,7 @@ class MHA_pro(nn.Module):
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
@@ -283,7 +283,7 @@ class FixedNorm(nn.Module):
x_normed = x / (norm_x * self.dd + 1e-12)
return x_normed
########################################################################################################
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
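
On the FixedNorm forward at the top of the previous hunk: if self.dd is d ** -0.5 (which the class name and the formula suggest, though the constructor is not shown here), dividing by norm_x * self.dd is exactly RMS normalization without a learned gain. A small sketch under that assumption:

import torch

def fixed_norm(x, eps=1e-12):
    # assumes dd = d ** -0.5, so ||x|| * dd equals the RMS of x along the last dim
    d = x.shape[-1]
    norm_x = x.norm(2, dim=-1, keepdim=True)
    return x / (norm_x * d ** -0.5 + eps)

x = torch.randn(2, 5, 16)
print(fixed_norm(x).pow(2).mean(dim=-1))   # close to 1.0 everywhere: unit RMS per position
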
@@ -300,6 +300,8 @@ class Block(nn.Module):
self.ln2 = nn.LayerNorm(config.n_embd)
if config.model_type == 'RWKV':
# self.ln1 = FixedNorm(config.n_embd)
# self.ln2 = FixedNorm(config.n_embd)
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
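
The Block wires a time-mixing module and a channel-mixing module behind the two LayerNorms. Its forward pass is not part of this hunk; a typical pre-LN residual wiring for such a block would look like the sketch below (an assumption about the omitted code, not a quote of it):

import torch
import torch.nn as nn

class BlockSketch(nn.Module):
    # illustrative stand-in: ln1/ln2 normalize, attn/mlp are the time/channel mixers
    def __init__(self, ln1, ln2, attn, mlp):
        super().__init__()
        self.ln1, self.ln2, self.attn, self.mlp = ln1, ln2, attn, mlp

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # time-mix with residual connection
        x = x + self.mlp(self.ln2(x))   # channel-mix with residual connection
        return x

blk = BlockSketch(nn.LayerNorm(16), nn.LayerNorm(16), nn.Linear(16, 16), nn.Linear(16, 16))
print(blk(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 8, 16])
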
@@ -332,6 +334,7 @@ class GPT(nn.Module):
self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.ctx_len = config.ctx_len
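
time_out is a learned per-position scale (initialized to ones) applied to the hidden states right before the bias-free head, so it scales each position's logits; values below 1 flatten the softmax for early positions, which is what the "reduce confidence of early tokens" comment refers to. A tiny numeric illustration of that mechanism:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])
for scale in (1.0, 0.5):                      # 0.5 mimics a down-weighted early position
    probs = F.softmax(logits * scale, dim=-1)
    print(scale, [round(float(p), 3) for p in probs])
# smaller scale -> flatter distribution -> less confident prediction
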
@@ -345,32 +348,33 @@ class GPT(nn.Module):
ww[k] *= math.sqrt(self.config.vocab_size)
else:
ww[k] *= math.sqrt(self.config.n_embd)
- ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
+ ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
elif 'head.weight' in k:
- ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
+ ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
elif 'blocks.' in k:
block_id = int(k.split('.')[1])
if 'receptance.weight' in k:
- ww[k] *= 0 # 0 works the best
+ ww[k] *= 0 # init with zero matrix
elif 'attn.key.weight' in k:
- ww[k] *= 0 # 0 works the best
+ ww[k] *= 0 # init with zero matrix
elif 'attn.output.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'mlp.weight.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'attn.time_w' in k:
if self.config.n_head > 1: # different time_w for different head
for h in range(self.config.n_head):
curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
curve = torch.pow(curve, 24) # concentrated effect
curve = curve - torch.mean(curve) + 1 # normalize mean to 1
curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
curve = curve * 2 + 0.7
for h in range(self.config.n_head):
if self.config.n_head > 1:
mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
ww[k][h] = (1 - mix_strength) + curve * mix_strength
# special tweaks because of time_shift
ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
ww[k][h][self.config.ctx_len - 1] = 1
# print(k, h, mix_strength, ww[k][h])
else:
mix_strength = 0.5
ww[k][h] = (1 - mix_strength) + curve * mix_strength
# special tweaks because of time_shift
ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
ww[k][h][self.config.ctx_len - 1] = 1
# print(k, h, mix_strength, ww[k][h])
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
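
The time_w initialization above builds a decay curve over the context window and blends it toward a flat profile with a head-dependent mix_strength. A small standalone script to see what the 0.9-decay variant of those initial curves looks like (ctx_len and n_head values are arbitrary, and the end-of-context tweaks are omitted):

import torch

ctx_len, n_head = 8, 4
# exponential decay toward older positions, rescaled as in the init above
curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range(ctx_len)])
curve = curve * 2 + 0.7

for h in range(n_head):
    # mix_strength runs from 1 (strongly decayed) down to -0.2 (slightly inverted)
    mix_strength = 1 - 1.2 * h / (n_head - 1)
    time_w_h = (1 - mix_strength) + curve * mix_strength
    print(h, [round(float(x), 2) for x in time_w_h])
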
@@ -421,7 +425,7 @@ class GPT(nn.Module):
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
- optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+ optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
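
The change above threads an explicit eps value from the training config through to AdamW instead of relying on the optimizer's default. A minimal usage sketch with a hypothetical config object; the attribute names mirror those used in configure_optimizers, but the object itself and its values are illustrative:

import types
import torch

train_config = types.SimpleNamespace(
    learning_rate=6e-4,
    betas=(0.9, 0.99),
    eps=1e-8,            # forwarded to AdamW by configure_optimizers
    weight_decay=0.01,
)

params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = torch.optim.AdamW(
    params,
    lr=train_config.learning_rate,
    betas=train_config.betas,
    eps=train_config.eps,
)
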
@@ -433,6 +437,7 @@ class GPT(nn.Module):
x = self.blocks(x)
x = self.ln_f(x)
x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
x = self.head(x)
loss = None