better tinyAtt

main
BlinkDL 3 years ago
parent de8bae7778
commit a268cd2e40

@@ -234,10 +234,11 @@ class Block(nn.Module):
         self.ffn = RWKV_ChannelMix(args, layer_id)

         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
-            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+            self.tiny_ln = nn.LayerNorm(args.n_embd)
+            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

     def forward(self, x, x_emb=None):
         args = self.args
@@ -255,11 +256,12 @@ class Block(nn.Module):
         x = x + self.ffn(self.ln2(x))

         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            q = self.head_q(x)[:, :T, :]
-            k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
-            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
-            x = x + c @ self.head_v(x_emb)
+            xx = self.tiny_ln(x)
+            q = self.tiny_q(xx)[:, :T, :]
+            k = self.tiny_k(xx)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
+            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
+            x = x + c @ self.tiny_v(x_emb)
         return x

@@ -70,7 +70,6 @@ if __name__ == "__main__":
     parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
     parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
     parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer
-    parser.add_argument("--tiny_att_downscale", default=0, type=float)
     parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
