@@ -10,7 +10,7 @@ from torch.nn import functional as F
 logger = logging.getLogger(__name__)
 
 ########################################################################################################
-# Block: RWKV Time-mix + RWKV Channel-mix
+# RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
 
 class RWKV_TimeMix(nn.Module):
@@ -18,15 +18,15 @@ class RWKV_TimeMix(nn.Module):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         self.layer_id = layer_id
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.n_head = config.n_head
         self.head_size = config.n_embd // config.n_head
 
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
-        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
-        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
-        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
@@ -38,7 +38,7 @@ class RWKV_TimeMix(nn.Module):
 
     def forward(self, x):
        B, T, C = x.size()
-        TT = self.ctx_size
+        TT = self.ctx_len
         w = F.pad(self.time_w, (0, TT))
         w = torch.tile(w, [TT])
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
@@ -88,7 +88,7 @@ class RWKV_ChannelMix(nn.Module):
         return y
 
 ########################################################################################################
-# Block: Multi-head Attention + Rotary Encoding + GeGLU FFN
+# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
 ########################################################################################################
 
 class RotaryEmbedding(torch.nn.Module):
@@ -119,19 +119,20 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:]
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
-class RotaryMHA(nn.Module):
-    def __init__(self, config):
+class MHA_rotary(nn.Module):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head
 
         self.query = nn.Linear(config.n_embd, config.n_embd)
         self.key = nn.Linear(config.n_embd, config.n_embd)
         self.value = nn.Linear(config.n_embd, config.n_embd)
 
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
@@ -163,8 +164,9 @@ class RotaryMHA(nn.Module):
         return x
 
 class GeGLU(torch.nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
@@ -176,22 +178,23 @@ class GeGLU(torch.nn.Module):
         return y
 
 ########################################################################################################
-# Block: MHA+ (with even more tricks)
+# MHA_pro: with more tricks
 ########################################################################################################
 
-class RotaryMHA_Plus(nn.Module):
-    def __init__(self, config):
+class MHA_pro(nn.Module):
+    def __init__(self, config, layer_id):
         super().__init__()
+        self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head
 
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
-        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
-        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
-        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
         self.query = nn.Linear(config.n_embd, config.n_embd)
@@ -207,7 +210,7 @@ class RotaryMHA_Plus(nn.Module):
 
     def forward(self, x):
         B, T, C = x.size()
-        TT = self.ctx_size
+        TT = self.ctx_len
         w = F.pad(self.time_w, (0, TT))
         w = torch.tile(w, [TT])
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
@@ -280,9 +283,9 @@ class FixedNorm(nn.Module):
 ########################################################################################################
 
 class GPTConfig:
-    def __init__(self, vocab_size, ctx_size, **kwargs):
+    def __init__(self, vocab_size, ctx_len, **kwargs):
         self.vocab_size = vocab_size
-        self.ctx_size = ctx_size
+        self.ctx_len = ctx_len
         for k,v in kwargs.items():
             setattr(self, k, v)
 
@@ -298,12 +301,12 @@ class Block(nn.Module):
             self.ln2 = FixedNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
-        elif config.model_type == 'RotaryMHA':
-            self.attn = RotaryMHA(config)
-            self.mlp = GeGLU(config)
-        elif config.model_type == 'MHA-Plus':
-            self.attn = RotaryMHA_Plus(config)
-            self.mlp = RWKV_ChannelMix(config)
+        elif config.model_type == 'MHA_rotary':
+            self.attn = MHA_rotary(config, layer_id)
+            self.mlp = GeGLU(config, layer_id)
+        elif config.model_type == 'MHA_pro':
+            self.attn = MHA_pro(config, layer_id)
+            self.mlp = RWKV_ChannelMix(config, layer_id)
 
     def forward(self, x):
 
@@ -328,31 +331,40 @@ class GPT(nn.Module):
 
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        self.ctx_size = config.ctx_size
+        self.ctx_len = config.ctx_len
         self.apply(self._init_weights)
 
         if self.config.model_type == 'RWKV': # improve orthogonal weight init
+
+            token_diversity = pow(self.config.vocab_size / 200, 1/3)
+            token_diversity = 0.4 * min(max(token_diversity, 1), 2) # 200 -> 0.4, 1600 -> 0.8. ENG-char 0.4 CHN-char 0.8
+            print('token_diversity', token_diversity)
+
             ww = self.state_dict()
             for k in ww:
                 if 'tok_emb' in k:
                     if self.config.vocab_size > self.config.n_embd:
                         ww[k] *= math.sqrt(self.config.vocab_size)
                     else:
                         ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= 0.4
+                    ww[k] *= token_diversity
                 elif 'head.weight' in k:
-                    ww[k] *= 0.2
+                    ww[k] *= token_diversity
                 elif 'blocks.' in k:
                     block_id = int(k.split('.')[1])
                     if 'receptance.weight' in k:
-                        ww[k] *= 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5
                     elif 'attn.key.weight' in k:
-                        ww[k] *= 0.2
+                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                    elif 'attn.output.weight' in k:
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                    elif 'mlp.weight.weight' in k:
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
-    def get_ctx_size(self):
-        return self.ctx_size
+    def get_ctx_len(self):
+        return self.ctx_len
 
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -403,7 +415,7 @@ class GPT(nn.Module):
 
     def forward(self, idx, targets=None):
         B, T = idx.size()
-        assert T <= self.ctx_size, "Cannot forward, model block size is exhausted."
+        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
 
         x = self.tok_emb(idx)
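
For reference, the token_diversity factor introduced in GPT.__init__ above is just a clamped cube-root curve of the vocabulary size. A minimal standalone sketch of that formula follows (the example vocab sizes are arbitrary illustrations, not values from this patch); it reproduces the figures in the inline comment:

# Sketch of the token_diversity scaling added in the hunk above.
# token_diversity = 0.4 * clamp((vocab_size / 200) ** (1/3), 1, 2)
# so vocab_size=200 gives 0.4 and vocab_size=1600 gives 0.8, matching the comment.

def token_diversity(vocab_size: int) -> float:
    t = pow(vocab_size / 200, 1 / 3)
    return 0.4 * min(max(t, 1), 2)

if __name__ == "__main__":
    for v in (200, 1600, 50000):  # example sizes only
        print(v, token_diversity(v))  # 200 -> 0.4, 1600 -> 0.8, 50000 -> 0.8 (clamped)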