@@ -84,7 +84,7 @@ class RWKV_TimeMix(nn.Module):
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
         self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
-        self.time_shift = nn.ZeroPad2d((0,0,1,0))
+        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
 
         self.key = nn.Linear(config.n_embd, config.n_attn)
         self.value = nn.Linear(config.n_embd, config.n_attn)
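Note on the padding change (my reading, not stated in the diff): for a (B, T, C) input, nn.ZeroPad2d pads the last two dimensions, so (0,0,1,-1) adds one zero step at the front of the time axis and crops one step off the end. The module therefore returns an already-shifted (B, T, C) tensor, whereas the old (0,0,1,0) version returned (B, T+1, C) and relied on the caller's [:, :-1] crop. A minimal standalone sketch with illustrative shapes:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 5, 8)                # toy (B, T, C) tensor

    pad_old = nn.ZeroPad2d((0, 0, 1, 0))    # old: pad 1 step at the front of T
    pad_new = nn.ZeroPad2d((0, 0, 1, -1))   # new: pad 1 at the front, crop 1 at the back

    shifted_old = pad_old(x)[:, :-1, :]     # crop done by the caller, as before this commit
    shifted_new = pad_new(x)                # crop now happens inside the module

    assert shifted_new.shape == x.shape
    assert torch.equal(shifted_old, shifted_new)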
@@ -110,7 +110,7 @@ class RWKV_TimeMix(nn.Module):
         self.mask = self.mask[:T, :T]
         w = w.masked_fill(self.mask == 0, 0)
 
-        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
         if hasattr(self, 'tiny_att'):
             tiny_att = self.tiny_att(x, self.mask)
 
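The forward-pass side of the same change: because time_shift now crops internally, the channel slice moves onto the module's argument instead of its output. The old and new expressions should be numerically identical; a quick check under the same assumptions as above (toy shapes, even C as the C//2 split already requires):

    import torch
    import torch.nn as nn

    B, T, C = 2, 5, 8
    x = torch.randn(B, T, C)
    shift_old = nn.ZeroPad2d((0, 0, 1, 0))
    shift_new = nn.ZeroPad2d((0, 0, 1, -1))

    # before: shift everything, crop the extra step, then keep the first half of channels
    old = torch.cat([shift_old(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)
    # after: keep the first half of channels, then shift (crop is internal)
    new = torch.cat([shift_new(x[:, :, :C//2]), x[:, :, C//2:]], dim=-1)

    assert torch.equal(old, new)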
@@ -118,7 +118,7 @@ class RWKV_TimeMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
-        k = torch.clamp(k, max=30) # clamp extreme values. e^30 = 10^13
+        k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
 
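On the added min=-60 (an inference, not stated in the diff): k goes straight into torch.exp and then a running cumsum, so bounding it below as well as above keeps exp(k) inside roughly [9e-27, 1e13] instead of letting the lower tail underflow to exact zeros. A small sketch of the difference:

    import torch

    k = torch.tensor([-500.0, -60.0, 0.0, 30.0, 500.0])

    old = torch.exp(torch.clamp(k, max=30))           # lower tail can underflow to 0
    new = torch.exp(torch.clamp(k, max=30, min=-60))  # both tails bounded

    print(old)   # ~[0.0, 8.8e-27, 1.0, 1.07e13, 1.07e13]
    print(new)   # ~[8.8e-27, 8.8e-27, 1.0, 1.07e13, 1.07e13]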
@@ -138,7 +138,7 @@ class RWKV_ChannelMix(nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
-        self.time_shift = nn.ZeroPad2d((0,0,1,0))
+        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
 
         hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
         self.key = nn.Linear(config.n_embd, hidden_sz)
@@ -152,7 +152,7 @@ class RWKV_ChannelMix(nn.Module):
     def forward(self, x):
         B, T, C = x.size()
 
-        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
         k = self.key(x)
         v = self.value(x)
         r = self.receptance(x)
@@ -235,7 +235,7 @@ class MHA_rotary(nn.Module):
         self.head_size = config.n_attn // config.n_head
 
         if time_shift:
-            self.time_shift = nn.ZeroPad2d((0,0,1,0))
+            self.time_shift = nn.ZeroPad2d((0,0,1,-1))
 
         self.query = nn.Linear(config.n_embd, config.n_attn)
         self.key = nn.Linear(config.n_embd, config.n_attn)
@@ -252,7 +252,7 @@ class MHA_rotary(nn.Module):
         B, T, C = x.size()
 
         if hasattr(self, 'time_shift'):
-            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+            x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
 
         q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
@@ -281,7 +281,7 @@ class GeGLU(torch.nn.Module):
         self.layer_id = layer_id
 
         if time_shift:
-            self.time_shift = nn.ZeroPad2d((0,0,1,0))
+            self.time_shift = nn.ZeroPad2d((0,0,1,-1))
 
         hidden_sz = 3 * config.n_ffn
         self.key = nn.Linear(config.n_embd, hidden_sz)
@@ -291,7 +291,7 @@ class GeGLU(torch.nn.Module):
     def forward(self, x):
         B, T, C = x.size()
         if hasattr(self, 'time_shift'):
-            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+            x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
 
         k = self.key(x)
         v = self.value(x)
@@ -317,7 +317,7 @@ class MHA_pro(nn.Module):
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
         self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
-        self.time_shift = nn.ZeroPad2d((0,0,1,0))
+        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
         self.query = nn.Linear(config.n_embd, config.n_attn)
         self.key = nn.Linear(config.n_embd, config.n_attn)
         self.value = nn.Linear(config.n_embd, config.n_attn)
@@ -338,7 +338,7 @@ class MHA_pro(nn.Module):
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
 
-        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-shift mixing
+        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing
         q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)