diff --git a/src/model.py b/src/model.py index 18b6250..c06bd1d 100644 --- a/src/model.py +++ b/src/model.py @@ -84,7 +84,7 @@ class RWKV_TimeMix(nn.Module): self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1)) self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len))) - self.time_shift = nn.ZeroPad2d((0,0,1,0)) + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) @@ -110,7 +110,7 @@ class RWKV_TimeMix(nn.Module): self.mask = self.mask[:T, :T] w = w.masked_fill(self.mask == 0, 0) - x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) if hasattr(self, 'tiny_att'): tiny_att = self.tiny_att(x, self.mask) @@ -118,7 +118,7 @@ class RWKV_TimeMix(nn.Module): v = self.value(x) r = self.receptance(x) - k = torch.clamp(k, max=30) # clamp extreme values. e^30 = 10^13 + k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13 k = torch.exp(k) sum_k = torch.cumsum(k, dim=1) @@ -138,7 +138,7 @@ class RWKV_ChannelMix(nn.Module): def __init__(self, config, layer_id): super().__init__() self.layer_id = layer_id - self.time_shift = nn.ZeroPad2d((0,0,1,0)) + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating self.key = nn.Linear(config.n_embd, hidden_sz) @@ -152,7 +152,7 @@ class RWKV_ChannelMix(nn.Module): def forward(self, x): B, T, C = x.size() - x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) k = self.key(x) v = self.value(x) r = self.receptance(x) @@ -235,7 +235,7 @@ class MHA_rotary(nn.Module): self.head_size = config.n_attn // config.n_head if time_shift: - self.time_shift = nn.ZeroPad2d((0,0,1,0)) + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.query = nn.Linear(config.n_embd, config.n_attn) self.key = nn.Linear(config.n_embd, config.n_attn) @@ -252,7 +252,7 @@ class MHA_rotary(nn.Module): B, T, C = x.size() if hasattr(self, 'time_shift'): - x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) @@ -281,7 +281,7 @@ class GeGLU(torch.nn.Module): self.layer_id = layer_id if time_shift: - self.time_shift = nn.ZeroPad2d((0,0,1,0)) + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) hidden_sz = 3 * config.n_ffn self.key = nn.Linear(config.n_embd, hidden_sz) @@ -291,7 +291,7 @@ class GeGLU(torch.nn.Module): def forward(self, x): B, T, C = x.size() if hasattr(self, 'time_shift'): - x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) k = self.key(x) v = self.value(x) @@ -317,7 +317,7 @@ class MHA_pro(nn.Module): self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1)) self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len))) - self.time_shift = nn.ZeroPad2d((0,0,1,0)) + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.query = nn.Linear(config.n_embd, config.n_attn) self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) @@ -338,7 +338,7 @@ class MHA_pro(nn.Module): w = w[:, :, TT-1:] # w is now a circulant matrix w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :] - x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-shift mixing + x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)