@@ -14,9 +14,10 @@ logger = logging.getLogger(__name__)
 ########################################################################################################
 
 class RWKV_TimeMix(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
         assert config.n_embd % config.n_head == 0
+        self.layer_id = layer_id
         self.ctx_size = config.ctx_size
         self.n_head = config.n_head
         self.head_size = config.n_embd // config.n_head
@@ -58,13 +59,14 @@ class RWKV_TimeMix(nn.Module):
         wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
 
         y = torch.sigmoid(r) * wkv / sum_k
 
         y = self.output(y) * self.time_gamma[:T, :]
         return y
 
 class RWKV_ChannelMix(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
         self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
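A note on the `wkv` context line above: the einsum contracts the per-head time-weight matrix `w` with the keyed values `k * v` over source positions. A loop-form sketch of the same contraction, with shapes inferred from the subscript letters (they are not shown in this hunk), may make the indexing easier to follow:

```python
# Sketch: loop equivalent of torch.einsum('htu,buhc->bthc', w, k * v).
# Shapes inferred from the subscripts: w is (H, T, T); k, v are (B, T, H, head_size).
import torch

B, T, H, S = 2, 8, 4, 16              # illustrative sizes only
w = torch.randn(H, T, T)
k = torch.randn(B, T, H, S)
v = torch.randn(B, T, H, S)

kv = k * v                            # elementwise product, (B, T, H, S)
wkv_loop = torch.zeros(B, T, H, S)
for h in range(H):
    for t in range(T):
        for u in range(T):            # accumulate over source positions u
            wkv_loop[:, t, h] += w[h, t, u] * kv[:, u, h]

wkv = torch.einsum('htu,buhc->bthc', w, kv)
assert torch.allclose(wkv, wkv_loop, atol=1e-5)
```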
@@ -265,7 +267,7 @@ class RMSNorm(nn.Module):
         x_normed = x / (norm_x * self.dd + 1e-12)
         return self.weight * x_normed
 
-class SimpleRMSNorm(nn.Module):
+class FixedNorm(nn.Module):
     def __init__(self, d):
         super().__init__()
         self.dd = d ** (-1. / 2)
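Only `__init__` of the renamed class is visible in this hunk. Judging from the `RMSNorm.forward` context lines just above, `FixedNorm` presumably applies the same RMS-style normalization without a learned weight; a sketch of the full module under that assumption:

```python
# Assumed full FixedNorm: the forward here mirrors the RMSNorm shown above,
# minus the learned self.weight (only __init__ appears in this hunk).
import torch.nn as nn

class FixedNorm(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.dd = d ** (-1. / 2)       # 1/sqrt(d), so norm_x * dd is the RMS of x

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        x_normed = x / (norm_x * self.dd + 1e-12)   # divide by RMS; no learned scale
        return x_normed
```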
@@ -285,18 +287,17 @@ class GPTConfig:
             setattr(self, k, v)
 
 class Block(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
 
         self.ln1 = nn.LayerNorm(config.n_embd)
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
-            self.ln1 = nn.Identity()
-            # self.ln1 = SimpleRMSNorm(config.n_embd) # turn on this if you see nan in large RWKV models
-            self.ln2 = SimpleRMSNorm(config.n_embd)
-            self.attn = RWKV_TimeMix(config)
-            self.mlp = RWKV_ChannelMix(config)
+            self.ln1 = FixedNorm(config.n_embd)
+            self.ln2 = FixedNorm(config.n_embd)
+            self.attn = RWKV_TimeMix(config, layer_id)
+            self.mlp = RWKV_ChannelMix(config, layer_id)
         elif config.model_type == 'RotaryMHA':
             self.attn = RotaryMHA(config)
             self.mlp = GeGLU(config)
@@ -305,6 +306,7 @@ class Block(nn.Module):
             self.mlp = RWKV_ChannelMix(config)
 
     def forward(self, x):
+
         x = x + self.attn(self.ln1(x))
         x = x + self.mlp(self.ln2(x))
 
@@ -317,10 +319,10 @@ class GPT(nn.Module):
 
         self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
 
-        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+        self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
 
         if config.model_type == 'RWKV':
-            self.ln_f = SimpleRMSNorm(config.n_embd)
+            self.ln_f = FixedNorm(config.n_embd)
         else:
             self.ln_f = nn.LayerNorm(config.n_embd)
 
@@ -329,15 +331,23 @@ class GPT(nn.Module):
         self.ctx_size = config.ctx_size
         self.apply(self._init_weights)
 
-        if self.config.model_type == 'RWKV':
+        if self.config.model_type == 'RWKV': # improve orthogonal weight init
             ww = self.state_dict()
-            for k in ww: # reduce weight to avoid nan
-                if 'receptance.weight' in k:
-                    ww[k] /= math.pow(config.n_embd, 0.5)
-                elif 'key.weight' in k:
-                    ww[k] /= math.pow(config.n_embd, 0.25)
-                elif 'value.weight' in k:
-                    ww[k] /= math.pow(config.n_embd, 0.25)
+            for k in ww:
+                if 'tok_emb' in k:
+                    if self.config.vocab_size > self.config.n_embd:
+                        ww[k] *= math.sqrt(self.config.vocab_size)
+                    else:
+                        ww[k] *= math.sqrt(self.config.n_embd)
+                    ww[k] *= 0.4
+                elif 'head.weight' in k:
+                    ww[k] *= 0.2
+                elif 'blocks.' in k:
+                    block_id = int(k.split('.')[1])
+                    if 'receptance.weight' in k:
+                        ww[k] *= 0.5
+                    elif 'attn.key.weight' in k:
+                        ww[k] *= 0.2
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
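Worth noting for the rescaling loop above: `state_dict()` returns tensors that share storage with the live parameters, so the in-place `*=` (and the earlier `/=`) edits take effect on the model itself. A minimal demonstration of that mechanism, using a toy module and an illustrative factor:

```python
# Demo: in-place edits through state_dict() rescale the live parameters,
# which is what the initialization loop above relies on.
import torch
import torch.nn as nn

model = nn.Linear(4, 4, bias=False)
before = model.weight.data.clone()

ww = model.state_dict()               # tensors alias the module's parameters
for k in ww:
    if 'weight' in k:
        ww[k] *= 0.5                  # in-place, so model.weight changes too

assert torch.allclose(model.weight.data, before * 0.5)
```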
@@ -347,14 +357,14 @@ class GPT(nn.Module):
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
             if self.config.model_type == 'RWKV':
+                gain = 1.0
                 if isinstance(module, nn.Linear):
-                    gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
-                    depth_factor = 1 # min(1, 1 / math.sqrt(self.config.n_layer / 5))
-                    nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor)
-                else:
-                    nn.init.orthogonal_(module.weight, gain = 1.0)
+                    if module.weight.data.shape[0] > module.weight.data.shape[1]:
+                        gain = math.sqrt(module.weight.data.shape[0] / module.weight.data.shape[1])
+                nn.init.orthogonal_(module.weight, gain=gain)
+
             else:
                 module.weight.data.normal_(mean=0.0, std=0.01)
 
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
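The replacement branch raises the orthogonal gain for "tall" linear layers (out_features > in_features). With plain `nn.init.orthogonal_`, a tall weight has orthonormal columns, so the per-output scale shrinks by roughly sqrt(in/out); `gain = sqrt(out/in)` appears intended to compensate. A quick sketch of the effect (sizes are illustrative):

```python
# Sketch: effect of gain = sqrt(out/in) on orthogonal init for a tall weight.
import math
import torch
import torch.nn as nn

out_f, in_f = 256, 64
w_plain = torch.empty(out_f, in_f)
nn.init.orthogonal_(w_plain)                      # gain 1.0

w_scaled = torch.empty(out_f, in_f)
nn.init.orthogonal_(w_scaled, gain=math.sqrt(out_f / in_f))

x = torch.randn(in_f)
print((w_plain @ x).std().item())    # ~0.5: per-output scale shrunk by sqrt(in/out)
print((w_scaled @ x).std().item())   # ~1.0: comparable to the input scale
```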
@@ -396,14 +406,14 @@ class GPT(nn.Module):
         assert T <= self.ctx_size, "Cannot forward, model block size is exhausted."
 
         x = self.tok_emb(idx)
 
         x = self.blocks(x)
 
         x = self.ln_f(x)
-        logits = self.head(x)
+        x = self.head(x)
 
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(logits.view(-1, logits.size(-1)), targets.view(-1))
+            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1))
 
-        return logits, loss
+        return x, loss
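`LabelSmoothingCrossEntropy` is defined elsewhere in the file, outside this diff. A common formulation is sketched below; this is an assumed implementation, not necessarily the author's. With `smoothing=1e-6` as used above, it behaves almost identically to plain cross-entropy:

```python
# Assumed implementation of a typical label-smoothing cross-entropy;
# the actual class used above is defined outside this diff.
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        logprobs = F.log_softmax(logits, dim=-1)
        nll = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        smooth = -logprobs.mean(dim=-1)                    # uniform-target term
        loss = (1.0 - self.smoothing) * nll + self.smoothing * smooth
        return loss.mean()
```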