@@ -58,10 +58,9 @@ class RWKV_TimeMix(nn.Module):
         v = v.view(B, T, self.n_head, self.head_size)
 
         wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
 
-        y = torch.sigmoid(r) * wkv / sum_k
-        y = self.output(y) * self.time_gamma[:T, :]
-        return y
+        rwkv = torch.sigmoid(r) * wkv / sum_k
+        return self.output(rwkv) * self.time_gamma[:T, :]
 
 class RWKV_ChannelMix(nn.Module):
     def __init__(self, config, layer_id):
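
For readers skimming the hunk above: the renamed `rwkv` is the time-mix output, i.e. a sigmoid receptance gate applied to a per-head causal weighted sum of `k * v`, normalized by `sum_k`. A minimal runnable sketch of that computation with toy shapes (the construction of `w`, `k`, `r`, and `sum_k` below is illustrative stand-in code, not the repo's):

```python
import torch

# Toy shapes for illustration only; the real module derives these from config.
B, T, n_head, head_size = 2, 4, 2, 4
C = n_head * head_size

# Stand-ins for tensors built earlier in forward() (not shown in the hunk):
w = torch.rand(n_head, T, T).tril()       # causal per-head time weights w[h, t, u], u <= t
k = torch.rand(B, T, n_head, head_size)   # keys (positive here so the normalizer is safe)
v = torch.rand(B, T, n_head, head_size)   # values
r = torch.randn(B, T, C)                  # receptance, pre-sigmoid
sum_k = torch.einsum('htu,buhc->bthc', w, k).contiguous().view(B, T, C)  # illustrative normalizer

# The kept einsum: for each (b, t, h, c), sum over u <= t of w[h, t, u] * (k * v)[b, u, h, c].
wkv = torch.einsum('htu,buhc->bthc', w, k * v).contiguous().view(B, T, C)

rwkv = torch.sigmoid(r) * wkv / sum_k     # gate * normalized weighted sum
print(rwkv.shape)                         # torch.Size([2, 4, 8])
```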
@@ -82,10 +81,10 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
-        wkv = self.weight(F.mish(k) * v) # mish is a bit better than gelu
+        wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu
 
-        y = torch.sigmoid(r) * wkv
-        return y
+        rwkv = torch.sigmoid(r) * wkv
+        return rwkv
 
 ########################################################################################################
 # MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
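
The channel-mix rename makes the symmetry with the time-mix explicit: the output is again `sigmoid(r)` gating a value path, here `weight(mish(k) * v)`. A self-contained sketch of just this gating pattern (layer sizes are made up, and the real module's remaining pieces are omitted):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelMixSketch(nn.Module):
    """Sketch of the R-gated channel mix above; sizes are illustrative."""
    def __init__(self, n_embd, hidden_sz):
        super().__init__()
        self.key = nn.Linear(n_embd, hidden_sz)
        self.value = nn.Linear(n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, n_embd)
        self.receptance = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)
        wkv = self.weight(F.mish(k) * v)  # mish(k) * v in place of a plain GeLU FFN
        return torch.sigmoid(r) * wkv     # receptance gate, same pattern as the time-mix

x = torch.randn(2, 4, 8)
print(ChannelMixSketch(8, 32)(x).shape)   # torch.Size([2, 4, 8])
```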
@@ -335,11 +334,6 @@ class GPT(nn.Module):
         self.apply(self._init_weights)
 
         if self.config.model_type == 'RWKV': # improve orthogonal weight init
-            token_diversity = pow(self.config.vocab_size / 200, 1/3)
-            token_diversity = 0.4 * min(max(token_diversity, 1), 2) # 200 -> 0.4, 1600 -> 0.8. ENG-char 0.4 CHN-char 0.8
-            print('token_diversity', token_diversity)
-
-
             ww = self.state_dict()
             for k in ww:
                 if 'tok_emb' in k:
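
The three deleted lines computed a vocab-size-dependent scale that the next hunk replaces with a hardcoded 0.4. To check the deleted comment's `200 -> 0.4, 1600 -> 0.8` arithmetic:

```python
def token_diversity(vocab_size):
    # The removed formula: cube root of vocab_size / 200, clipped to [1, 2], times 0.4.
    t = pow(vocab_size / 200, 1/3)
    return 0.4 * min(max(t, 1), 2)

print(token_diversity(200))     # 0.4  -- English char-level vocab
print(token_diversity(1600))    # 0.8  -- (1600/200)**(1/3) == 2 hits the clip ceiling (Chinese char-level)
print(token_diversity(50000))   # 0.8  -- anything larger stays clipped at 2
```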
@@ -347,19 +341,19 @@ class GPT(nn.Module):
                         ww[k] *= math.sqrt(self.config.vocab_size)
                     else:
                         ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= token_diversity
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might work better for Chinese
                 elif 'head.weight' in k:
-                    ww[k] *= token_diversity
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might work better for Chinese
                 elif 'blocks.' in k:
                     block_id = int(k.split('.')[1])
                     if 'receptance.weight' in k:
-                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5 gives similar results
                     elif 'attn.key.weight' in k:
-                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5 gives similar results
                     elif 'attn.output.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                     elif 'mlp.weight.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
 
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
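
For reference, the per-layer damping these hunks keep (only the comments change) shrinks the attn.output and mlp.weight projections of deeper blocks as 1 / sqrt(1 + block_id):

```python
for block_id in [0, 1, 3, 8]:
    print(block_id, 1 / pow(1 + block_id, 0.5))
# 0 1.0
# 1 0.7071067811865475
# 3 0.5
# 8 0.3333333333333333
```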
@@ -426,6 +420,6 @@ class GPT(nn.Module):
 
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1))
+            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1)) # try increasing smoothing if you see nan
 
         return x, loss
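
`LabelSmoothingCrossEntropy` is defined elsewhere in this file and not shown in the diff; the new comment's advice is to raise `smoothing` if training produces NaNs. A common formulation of label-smoothing cross entropy, as a sketch (the repo's class may differ in details):

```python
import torch
import torch.nn.functional as F

class LabelSmoothingCE(torch.nn.Module):
    """Illustrative label-smoothing CE, not the repo's implementation."""
    def __init__(self, smoothing=1e-6):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, targets):
        logp = F.log_softmax(logits, dim=-1)                        # (N, vocab)
        nll = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # -log p(target)
        smooth = -logp.mean(dim=-1)                                 # cross entropy vs uniform
        return ((1 - self.smoothing) * nll + self.smoothing * smooth).mean()

logits = torch.randn(6, 10)
targets = torch.randint(0, 10, (6,))
print(LabelSmoothingCE(smoothing=1e-6)(logits, targets))
```

With `smoothing=1e-6` the uniform term is nearly negligible, so the loss tracks plain cross entropy.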