@@ -77,7 +77,6 @@ class RWKV_TimeMix(nn.Module):
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
         self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

         self.time_shift = nn.ZeroPad2d((0,0,1,-1))

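For reference, the nn.ZeroPad2d((0,0,1,-1)) layer kept above is the token-shift trick: it pads one step of zeros at the start of the time axis and crops one step at the end, so position t receives the features of token t-1. A minimal sketch of that behavior (illustration only, not part of the patch; the toy tensor sizes are made up):

import torch
import torch.nn as nn

# Sketch: what nn.ZeroPad2d((0, 0, 1, -1)) does to a (B, T, C) tensor.
# One zero row is padded at the top of the time axis and one row is cropped
# at the bottom, so row t of the output holds the input features of row t-1.
time_shift = nn.ZeroPad2d((0, 0, 1, -1))

x = torch.arange(1, 7, dtype=torch.float32).view(1, 3, 2)  # B=1, T=3, C=2
shifted = time_shift(x)

assert torch.equal(shifted[0, 0], torch.zeros(2))    # first position sees zeros
assert torch.equal(shifted[0, 1:], x[0, :-1])        # position t sees token t-1
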
@@ -85,8 +84,8 @@ class RWKV_TimeMix(nn.Module):
         self.value = nn.Linear(config.n_embd, config.n_attn)
         self.receptance = nn.Linear(config.n_embd, config.n_attn)

-        if config.rwkv_tiny_attn > 0:
-            self.tiny_att = RWKV_TinyAttn(config)
+        # if config.rwkv_tiny_attn > 0:
+        #     self.tiny_att = RWKV_TinyAttn(config)

         self.output = nn.Linear(config.n_attn, config.n_embd)

@@ -102,12 +101,10 @@ class RWKV_TimeMix(nn.Module):
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
-        self.mask = self.mask[:T, :T]
-        w = w.masked_fill(self.mask == 0, 0)

         x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
-        if hasattr(self, 'tiny_att'):
-            tiny_att = self.tiny_att(x, self.mask)
+        # if hasattr(self, 'tiny_att'):
+        #     tiny_att = self.tiny_att(x, self.mask)

         k = self.key(x)
         v = self.value(x)
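A note on the two masking lines dropped above: assuming time_w is padded with ctx_len zeros and tiled before the reshape shown in this hunk (those lines sit just outside it), the resulting circulant matrix is already zero above the diagonal, so the explicit causal masked_fill was redundant. A small sketch that checks this under that assumption (illustration only; the toy head count and ctx_len are made up):

import torch
import torch.nn.functional as F

TT = 4                                                              # toy ctx_len
time_w = torch.arange(1, TT + 1, dtype=torch.float32).repeat(2, 1)  # 2 toy heads

w = F.pad(time_w, (0, TT))                  # (H, 2*TT): weights followed by zeros
w = torch.tile(w, [TT])                     # (H, 2*TT*TT)
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)  # (H, TT, 2*TT-1), as in the diff
w = w[:, :, TT - 1:]                        # (H, TT, TT), "w is now a circulant matrix"

# every entry above the diagonal comes from the zero padding, so w is already causal
assert torch.equal(w[0].triu(1), torch.zeros(TT, TT))
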
@@ -124,8 +121,8 @@ class RWKV_TimeMix(nn.Module):
         rwkv = torch.sigmoid(r) * wkv / sum_k

         rwkv = self.output(rwkv)
-        if hasattr(self, 'tiny_att'):
-            rwkv += tiny_att
+        # if hasattr(self, 'tiny_att'):
+        #     rwkv += tiny_att

         return rwkv * self.time_gamma[:T, :]

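For context on the unchanged line rwkv = torch.sigmoid(r) * wkv / sum_k: elsewhere in this file (outside the hunks shown) k is exponentiated and sum_k is its cumulative sum, so wkv / sum_k is a normalized, time-weighted average of past values and sigmoid(r) gates it per channel. A toy sketch under that assumption, using uniform time weights in place of the learned circulant w (illustration only):

import torch

B, T, C = 1, 4, 3                          # toy sizes
k = torch.exp(torch.randn(B, T, C))        # assumed: keys are exponentiated upstream
v = torch.randn(B, T, C)
r = torch.randn(B, T, C)

sum_k = torch.cumsum(k, dim=1)
wkv = torch.cumsum(k * v, dim=1)           # uniform weights stand in for the circulant w
rwkv = torch.sigmoid(r) * wkv / sum_k      # gated, normalized mixture of past values

assert rwkv.shape == (B, T, C)
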
@@ -437,6 +434,12 @@ class GPT(nn.Module):
         self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

+        self.head_q = nn.Linear(config.n_embd, 256)
+        self.head_q.scale_init = 0.01
+        self.head_k = nn.Linear(config.n_embd, 256)
+        self.head_k.scale_init = 0.01
+        self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
+
         self.ctx_len = config.ctx_len

         if self.config.model_type == 'RWKV':
@@ -497,8 +500,15 @@ class GPT(nn.Module):
         x = self.blocks(x)

         x = self.ln_f(x)
+
+        q = self.head_q(x)[:,:T,:]
+        k = self.head_k(x)[:,:T,:]
+        c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
+        c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
+        c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
+
         x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
-        x = self.head(x)
+        x = self.head(x) + c

         loss = None
         if targets is not None:
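The head_q / head_k / copy_mask additions above add what looks like a pointer-style copy path: each position scores every earlier position with a 256-dimensional dot product, the score is scattered onto the vocabulary id of the token seen at that earlier position, and the result c is added to the ordinary logits. A self-contained sketch of that computation (illustration only; the toy batch, length, vocabulary, and hidden sizes are made up, while the 256-dim projection and the 1/256 scale follow the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, V, D, H = 1, 5, 10, 256, 64          # toy sizes; D = 256 as in the diff
x = torch.randn(B, T, H)                   # stand-in for the ln_f output
idx = torch.randint(0, V, (B, T))          # stand-in for the input token ids
head_q, head_k = nn.Linear(H, D), nn.Linear(H, D)
copy_mask = torch.tril(torch.ones(T, T))

q = head_q(x)
k = head_k(x)
c = (q @ k.transpose(-2, -1)) * (1.0 / D)       # (B, T, T) pairwise scores
c = c.masked_fill(copy_mask == 0, 0)            # keep only earlier positions u <= t
c = c @ F.one_hot(idx, num_classes=V).float()   # (B, T, V): scores land on seen token ids

# position-by-position equivalent of the three lines above
manual = torch.zeros(B, T, V)
for t in range(T):
    for u in range(t + 1):
        manual[0, t, idx[0, u]] += q[0, t] @ k[0, u] / D
assert torch.allclose(c, manual, atol=1e-5)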