@@ -167,9 +167,72 @@ class GeGLU(torch.nn.Module):

    def forward(self, x):
        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y
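
# A hedged usage sketch for GeGLU (illustrative comments only; assumes the output
# projection maps back to n_embd, as is standard for a GLU-style FFN):
#   m = GeGLU(config)
#   y = m(torch.randn(2, config.ctx_size, config.n_embd))  # -> (2, ctx_size, n_embd)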
########################################################################################################
# Block: MHA+ (with even more tricks)
########################################################################################################

class RotaryMHA_Plus(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_size = config.ctx_size
        self.head_size = config.n_embd // config.n_head

        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
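        # time_w is one learned profile over relative distance per head; time_alpha /
        # time_beta factorize an extra per-column / per-row rescaling; time_gamma
        # rescales the block output per position. The lower-triangular mask is cached
        # as a non-trainable buffer so it moves with the module's device.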

        self.time_shift = nn.ZeroPad2d((0,0,1,0))
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
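        # A 1x1 Conv2d across the head axis is a learned n_head x n_head mixing matrix
        # applied to the stacked attention maps (talking heads, Shazeer et al. 2020).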

        self.output = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        TT = self.ctx_size
        w = F.pad(self.time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT-1:] # w is now a lower-triangular Toeplitz matrix
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
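        # Concretely: w[h, i, j] = time_w[h, TT-1-(i-j)] for j <= i and 0 otherwise,
        # so each entry depends only on how far token j lies in token i's past;
        # time_alpha then rescales columns and time_beta rescales rows.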

        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1) # time-mixing
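        # Half the channels come from the previous token (ZeroPad2d pads one zero step
        # at the front of the time axis, [:-1] drops the overflow), half from the
        # current token: a cheap one-step mixing along time.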
        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        cos, sin = self.rotary_emb(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)
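        # Partial RoPE: only the first rotary_ndims channels of each head are rotated
        # by position; the remaining channels pass through unrotated (the rotary_pct
        # idea familiar from GPT-NeoX-style models).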

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
        att = F.softmax(att, dim=-1) # softmax
        att = att * w # time-weighting
        att = self.head_mix(att) # talking heads
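        # Note the ordering: w multiplies the attention matrix after softmax, so rows
        # no longer sum to 1; the learned Toeplitz weights re-scale how much each
        # relative distance contributes.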

        x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x) * self.time_gamma[:T, :]
        return x
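
# Hedged standalone check of the Toeplitz construction above (illustrative only;
# the names here are hypothetical and not part of the model):
#   TT = 4; tw = torch.arange(1., TT + 1).expand(1, TT)
#   w = F.pad(tw, (0, TT)).tile([TT])[:, :-TT].reshape(-1, TT, 2 * TT - 1)[:, :, TT - 1:]
#   # w[0] == [[4,0,0,0],[3,4,0,0],[2,3,4,0],[1,2,3,4]]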
########################################################################################################
# The GPT Model with our blocks
########################################################################################################

@@ -205,9 +268,12 @@ class Block(nn.Module):

        if config.model_type == 'RWKV':
            self.attn = RWKV_TimeMix(config)
            self.mlp = RWKV_ChannelMix(config)
        elif config.model_type == 'RotaryMHA':
            self.attn = RotaryMHA(config)
            self.mlp = GeGLU(config)
        elif config.model_type == 'MHA-Plus':
            self.attn = RotaryMHA_Plus(config)
            self.mlp = RWKV_ChannelMix(config)
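        # Three interchangeable block recipes: RWKV (TimeMix + ChannelMix), RotaryMHA
        # (rotary attention + GeGLU FFN), and MHA-Plus (RotaryMHA_Plus + RWKV_ChannelMix).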

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
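        # Pre-LN residual: normalize, transform, then add back to the residual stream.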