add MHA-plus model

main
BlinkDL 4 years ago
parent bcd4adb781
commit 447eae5841

@ -167,9 +167,72 @@ class GeGLU(torch.nn.Module):
def forward(self, x):
k = self.key(x)
v = self.value(x)
y = self.weight(F.gelu(k) * v)
y = self.weight(F.gelu(k) * v)
return y
########################################################################################################
# Block: MHA+ (with even more tricks)
########################################################################################################
class RotaryMHA_Plus(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
self.n_head = config.n_head
self.ctx_size = config.ctx_size
self.head_size = config.n_embd // config.n_head
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
self.time_shift = nn.ZeroPad2d((0,0,1,0))
self.query = nn.Linear(config.n_embd, config.n_embd)
self.key = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
self.output = nn.Linear(config.n_embd, config.n_embd)
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_size
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-mixing
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
q = torch.cat((q, query_pass), dim=-1)
k = torch.cat((k, key_pass), dim=-1)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
att = F.softmax(att, dim = -1) # softmax
att = att * w # time-weighting
att = self.head_mix(att) # talking heads
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
x = self.output(x) * self.time_gamma[:T, :]
return x
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
@ -205,9 +268,12 @@ class Block(nn.Module):
if config.model_type == 'RWKV':
self.attn = RWKV_TimeMix(config)
self.mlp = RWKV_ChannelMix(config)
else:
elif config.model_type == 'RotaryMHA':
self.attn = RotaryMHA(config)
self.mlp = GeGLU(config)
elif config.model_type == 'MHA-Plus':
self.attn = RotaryMHA_Plus(config)
self.mlp = RWKV_ChannelMix(config)
def forward(self, x):
x = x + self.attn(self.ln1(x))

@ -13,7 +13,10 @@ set_seed(42)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
model_type = 'RWKV' # 'RWKV' or 'RotaryMHA'
# RWKV is our proposed model - fastest when the ctx window is long - good performance
# RotaryMHA is usual Multi-head Attention + Rotary Encoding + GeGLU FFN
# MHA-Plus is a bit slow (lots of tricks), with excellent performance
model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
model_level = 'character' # 'character' or 'word'

Loading…
Cancel
Save