@@ -234,10 +234,11 @@ class Block(nn.Module):
         self.ffn = RWKV_ChannelMix(args, layer_id)
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
-            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+            self.tiny_ln = nn.LayerNorm(args.n_embd)
+            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
     def forward(self, x, x_emb=None):
         args = self.args
@@ -255,11 +256,12 @@ class Block(nn.Module):
         x = x + self.ffn(self.ln2(x))
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            q = self.head_q(x)[:, :T, :]
-            k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
-            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
-            x = x + c @ self.head_v(x_emb)
+            xx = self.tiny_ln(x)
+            q = self.tiny_q(xx)[:, :T, :]
+            k = self.tiny_k(xx)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
+            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
+            x = x + c @ self.tiny_v(x_emb)
 
         return x
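For review convenience, below is a standalone sketch of the tiny-attention path as it reads after this diff. The sizes B, T, C, D and the random inputs are illustrative assumptions, not values taken from the repo; only the computation mirrors the added lines.

import torch
import torch.nn as nn

B, T = 2, 16                      # batch size, sequence length (assumed)
C, D = 64, 16                     # stand-ins for args.n_embd, args.tiny_att_dim (assumed)

x = torch.randn(B, T, C)          # block state after the ffn residual
x_emb = torch.randn(B, T, C)      # raw token embeddings passed into Block.forward

tiny_ln = nn.LayerNorm(C)
tiny_q = nn.Linear(C, D, bias=False)
tiny_k = nn.Linear(C, D, bias=False)
tiny_v = nn.Linear(C, C, bias=False)
tiny_mask = torch.tril(torch.ones(T, T))       # causal lower-triangular mask

xx = tiny_ln(x)                                # LayerNorm before q/k projection (new in this diff)
q = tiny_q(xx)                                 # (B, T, D)
k = tiny_k(xx)                                 # (B, T, D)
c = (q @ k.transpose(-2, -1)) * D ** -0.5      # (B, T, T), scaled by 1/sqrt(tiny_att_dim)
c = c.masked_fill(tiny_mask == 0, 0)           # future positions zeroed (no softmax, so 0 not -inf)
x = x + c @ tiny_v(x_emb)                      # residual shortcut attending over the embeddings

The two behavioural changes relative to the removed head_* path are the LayerNorm applied before the q/k projections and the switch from the 1.0 / args.tiny_att_downscale factor to the standard 1/sqrt(d) attention scaling.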