@@ -54,10 +54,10 @@ class RWKV_TimeMix(nn.Module):
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
 
-        k = k.view(B, T, self.n_head, self.head_size)
-        v = v.view(B, T, self.n_head, self.head_size)
+        kv = (k * v).view(B, T, self.n_head, self.head_size)
 
-        wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
+        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, C)
 
         rwkv = torch.sigmoid(r) * wkv / sum_k
 
         return self.output(rwkv) * self.time_gamma[:T, :]
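A minimal standalone sketch of what the fused kv path computes, assuming toy sizes (B, T, n_head, head_size here are hypothetical, and w is an untrained stand-in for the time_w weights). It checks that the 'htu,buhc->bthc' einsum is a per-head weighted sum of k*v over source positions:

import torch

# Hypothetical sizes, for illustration only.
B, T, n_head, head_size = 2, 8, 4, 16
C = n_head * head_size

k = torch.exp(torch.randn(B, T, C))          # exponentiated keys, as in the hunk above
v = torch.randn(B, T, C)
w = torch.rand(n_head, T, T).tril()          # stand-in causal per-head time weights

kv = (k * v).view(B, T, n_head, head_size)   # fuse k*v once, then split into heads
wkv = torch.einsum('htu,buhc->bthc', w, kv).contiguous().view(B, T, C)

# The einsum sums over source position u for each target position t, per head h:
ref = torch.zeros(B, T, n_head, head_size)
for t in range(T):
    for u in range(T):
        ref[:, t] += w[:, t, u].view(1, n_head, 1) * kv[:, u]
assert torch.allclose(wkv, ref.view(B, T, C), rtol=1e-4, atol=1e-4)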
@@ -83,6 +83,7 @@ class RWKV_ChannelMix(nn.Module):
         r = self.receptance(x)
 
         wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu
 
         rwkv = torch.sigmoid(r) * wkv
 
         return rwkv
@@ -120,7 +121,7 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 class MHA_rotary(nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, config, layer_id, time_shift = False):
         super().__init__()
         self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
@@ -128,6 +129,9 @@ class MHA_rotary(nn.Module):
         self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head
 
+        if time_shift:
+            self.time_shift = nn.ZeroPad2d((0,0,1,0))
+
         self.query = nn.Linear(config.n_embd, config.n_embd)
         self.key = nn.Linear(config.n_embd, config.n_embd)
         self.value = nn.Linear(config.n_embd, config.n_embd)
@@ -142,6 +146,9 @@ class MHA_rotary(nn.Module):
     def forward(self, x):
         B, T, C = x.size()
 
+        if hasattr(self, 'time_shift'):
+            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+
         q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
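The time_shift branch added here is the same half-channel shift MHA_pro uses below: ZeroPad2d((0,0,1,0)) pads one zero row at the front of the time axis, and the [:, :-1] slice drops the last row, so the first C//2 channels see the previous timestep. A small self-contained check (toy shapes, not from this diff):

import torch
import torch.nn as nn

B, T, C = 1, 4, 6
x = torch.arange(B * T * C, dtype=torch.float32).view(B, T, C)

time_shift = nn.ZeroPad2d((0, 0, 1, 0))  # one zero row before t=0, nothing after

# First C//2 channels come from the previous timestep (zeros at t=0);
# the remaining channels keep the current timestep.
shifted = torch.cat([time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

assert torch.equal(shifted[:, 1:, :C//2], x[:, :-1, :C//2])    # shifted back one step
assert torch.equal(shifted[:, 0, :C//2], torch.zeros(B, C//2)) # zeros at t=0
assert torch.equal(shifted[:, :, C//2:], x[:, :, C//2:])       # second half untouched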
@@ -160,19 +167,27 @@ class MHA_rotary(nn.Module):
         x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
         x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
 
-        x = self.output(x) # output projection
+        x = self.output(x)
         return x
 
 class GeGLU(torch.nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, config, layer_id, time_shift = False):
         super().__init__()
         self.layer_id = layer_id
 
+        if time_shift:
+            self.time_shift = nn.ZeroPad2d((0,0,1,0))
+
         hidden_sz = 3 * config.n_embd
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)
 
     def forward(self, x):
+        B, T, C = x.size()
+        if hasattr(self, 'time_shift'):
+            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+
         k = self.key(x)
         v = self.value(x)
         y = self.weight(F.gelu(k) * v)
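For reference, the GeGLU block this hunk modifies is a GELU-gated linear unit: key and value both expand to hidden_sz = 3 * n_embd, gelu(k) gates v elementwise, and weight projects back down. A toy sketch (hypothetical n_embd, untrained layers, shown only for the shapes):

import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd = 8
hidden_sz = 3 * n_embd
key = nn.Linear(n_embd, hidden_sz)
value = nn.Linear(n_embd, hidden_sz)
weight = nn.Linear(hidden_sz, n_embd)

x = torch.randn(2, 5, n_embd)            # (B, T, C)
y = weight(F.gelu(key(x)) * value(x))    # gelu(k) gates v, then project back to n_embd
print(y.shape)                           # torch.Size([2, 5, 8])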
@@ -205,7 +220,7 @@ class MHA_pro(nn.Module):
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
 
         self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
 
         self.output = nn.Linear(config.n_embd, config.n_embd)
@@ -218,7 +233,7 @@ class MHA_pro(nn.Module):
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
 
-        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-mixing
+        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-shift mixing
         q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
@@ -300,9 +315,15 @@ class Block(nn.Module):
         if config.model_type == 'RWKV':
             self.attn = RWKV_TimeMix(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
 
         elif config.model_type == 'MHA_rotary':
             self.attn = MHA_rotary(config, layer_id)
             self.mlp = GeGLU(config, layer_id)
 
+        elif config.model_type == 'MHA_shift':
+            self.attn = MHA_rotary(config, layer_id, time_shift=True)
+            self.mlp = GeGLU(config, layer_id, time_shift=True)
+
         elif config.model_type == 'MHA_pro':
             self.attn = MHA_pro(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
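Net effect of the Block change: 'MHA_shift' is not a new module, just the existing rotary attention and GeGLU with their new time_shift flag switched on. A hedged usage sketch (the config fields are inferred from the constructor calls in this diff; the real training script may differ):

from types import SimpleNamespace

# Hypothetical config object; field names follow the constructors used above.
config = SimpleNamespace(model_type='MHA_shift', n_embd=512, n_head=8, ctx_len=1024)

# Block(config, layer_id) would then wire up:
#   self.attn = MHA_rotary(config, layer_id, time_shift=True)
#   self.mlp  = GeGLU(config, layer_id, time_shift=True)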