@@ -51,7 +51,7 @@ class RWKV_TimeMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)

-        k = torch.clamp(k, max=30) # clamp crazy values
+        k = torch.clamp(k, max=30) # clamp extreme values
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
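
Note on the clamp above: k feeds straight into torch.exp on the next line, so the
bound is what keeps the result finite in fp32. A standalone sketch of the failure
mode (illustrative values, not code from this repo):

    import torch

    k = torch.tensor([10.0, 30.0, 90.0])
    print(torch.exp(k))                 # exp(90) overflows fp32 -> inf
    print(torch.exp(k.clamp(max=30)))   # capped at exp(30) ~ 1.1e13, still finite
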
@@ -300,6 +300,8 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)

         if config.model_type == 'RWKV':
+            # self.ln1 = FixedNorm(config.n_embd)
+            # self.ln2 = FixedNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
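
The two added comments record a FixedNorm alternative that stays disabled; the
block itself keeps standard LayerNorm. For reference, a minimal sketch of a
pre-LN block with this layout (assuming the usual x + sublayer(ln(x)) wiring;
the constructor arguments here are illustrative, not the repo's):

    import torch.nn as nn

    class PreLNBlock(nn.Module):
        def __init__(self, n_embd, attn, mlp):
            super().__init__()
            self.ln1 = nn.LayerNorm(n_embd)  # normalizes input of the time-mix sublayer
            self.ln2 = nn.LayerNorm(n_embd)  # normalizes input of the channel-mix sublayer
            self.attn = attn                 # e.g. RWKV_TimeMix
            self.mlp = mlp                   # e.g. RWKV_ChannelMix

        def forward(self, x):
            x = x + self.attn(self.ln1(x))   # residual around time-mix
            x = x + self.mlp(self.ln2(x))    # residual around channel-mix
            return x
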
@@ -332,6 +334,7 @@ class GPT(nn.Module):
         self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])

         self.ln_f = nn.LayerNorm(config.n_embd)
+        self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

         self.ctx_len = config.ctx_len
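
time_out has shape (1, ctx_len, 1), so it broadcasts over the batch and channel
dimensions and contributes one learnable scalar per position; the hunk at -433
below multiplies it in just before the head. A quick shape check with made-up
sizes:

    import torch

    B, T, C, ctx_len = 2, 5, 8, 16           # hypothetical sizes
    time_out = torch.ones(1, ctx_len, 1)     # one scalar per position
    x = torch.randn(B, T, C)
    y = x * time_out[:, :T, :]               # broadcasts over batch and channels
    print(y.shape)                           # torch.Size([2, 5, 8])
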
@@ -345,30 +348,31 @@ class GPT(nn.Module):
                     ww[k] *= math.sqrt(self.config.vocab_size)
                 else:
                     ww[k] *= math.sqrt(self.config.n_embd)
-                ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
+                ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
             elif 'head.weight' in k:
-                ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
+                ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
             elif 'blocks.' in k:
                 block_id = int(k.split('.')[1])
                 if 'receptance.weight' in k:
-                    ww[k] *= 0 # 0 works the best
+                    ww[k] *= 0 # init with zero matrix
                 elif 'attn.key.weight' in k:
-                    ww[k] *= 0 # 0 works the best
+                    ww[k] *= 0 # init with zero matrix
                 elif 'attn.output.weight' in k:
                     ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                 elif 'mlp.weight.weight' in k:
                     ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                 elif 'attn.time_w' in k:
-                    if self.config.n_head > 1: # different time_w for different head
-                        for h in range(self.config.n_head):
-                            curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
-                            curve = torch.pow(curve, 24) # concentrated effect
-                            curve = curve - torch.mean(curve) + 1 # normalize mean to 1
-                            mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
-                            ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                            # special tweaks because of time_shift
-                            ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
-                            ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
-                            ww[k][h][self.config.ctx_len - 1] = 1
-                            # print(k, h, mix_strength, ww[k][h])
+                    curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
+                    curve = curve * 2 + 0.7
+                    for h in range(self.config.n_head):
+                        if self.config.n_head > 1: # different time_w for different head
+                            mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
+                        else:
+                            mix_strength = 0.5
+                        ww[k][h] = (1 - mix_strength) + curve * mix_strength
+                        # special tweaks because of time_shift
+                        ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
+                        ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
+                        ww[k][h][self.config.ctx_len - 1] = 1
+                        # print(k, h, mix_strength, ww[k][h])
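
The rewritten time_w init builds one shared exponential-decay curve and blends
it per head; note the first time_shift tweak now reads from index ctx_len - 3
rather than ctx_len - 2, fixing an off-by-one in the old line. Rerunning the +
lines above at toy sizes (ctx_len and n_head made up for illustration) shows
what each head ends up with:

    import torch

    ctx_len, n_head = 8, 4
    curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range(ctx_len)])
    curve = curve * 2 + 0.7   # ~2.7 at the newest position, decaying toward 0.7

    time_w = torch.empty(n_head, ctx_len)
    for h in range(n_head):
        # mix_strength runs from 1 (head 0, strongly decayed) to -0.2 (last head)
        mix_strength = 1 - 1.2 * h / (n_head - 1)
        time_w[h] = (1 - mix_strength) + curve * mix_strength
        # same time_shift tweaks as in the diff
        time_w[h][ctx_len - 3] = (time_w[h][ctx_len - 3] * 2 + 1) / 3
        time_w[h][ctx_len - 2] = (time_w[h][ctx_len - 2] * 1 + 2) / 3
        time_w[h][ctx_len - 1] = 1
        print(h, time_w[h])
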
@@ -421,7 +425,7 @@ class GPT(nn.Module):
             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
         return optimizer

     def forward(self, idx, targets=None):
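
The only change here is threading an explicit eps through to AdamW instead of
relying on PyTorch's default (1e-8); a larger eps damps updates whenever the
second-moment estimate is near zero. Usage sketch with illustrative values
(not this repo's config):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in for GPT
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=6e-4,
        betas=(0.9, 0.99),
        eps=1e-8,  # now read from train_config.eps instead of the default
    )
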
@@ -433,6 +437,7 @@ class GPT(nn.Module):
         x = self.blocks(x)

         x = self.ln_f(x)
+        x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
         x = self.head(x)

         loss = None
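
The new multiply sits between the final LayerNorm and the head; since the head
is a bias-free linear layer, scaling x by time_out scales each position's
logits by the same factor, and a factor below 1 flattens the softmax, which is
what "reduce confidence" means here. A standalone demonstration:

    import torch

    logits = torch.tensor([2.0, 0.0, -2.0])
    for scale in (1.0, 0.5):                  # possible per-position time_out values
        p = torch.softmax(logits * scale, dim=-1)
        print(scale, p)                       # smaller scale -> flatter distribution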