diff --git a/src/model.py b/src/model.py
index 40fdbe9..ef22713 100644
--- a/src/model.py
+++ b/src/model.py
@@ -167,9 +167,91 @@ class GeGLU(torch.nn.Module):
     def forward(self, x):
         k = self.key(x)
         v = self.value(x)
-        y = self.weight(F.gelu(k) * v)
+        y = self.weight(F.gelu(k) * v)
         return y
 
+########################################################################################################
+# Block: MHA+ (with even more tricks)
+########################################################################################################
+
+class RotaryMHA_Plus(nn.Module):
+    """Causal multi-head self-attention with rotary encoding and extra tricks.
+
+    Compared with plain rotary multi-head attention, this block also mixes
+    half of the channels with the previous time step, multiplies the
+    attention matrix (after the softmax) by a learned per-head time-weighting,
+    mixes the attention maps across heads with a 1x1 convolution
+    ("talking heads"), and scales the output by a learned per-position factor.
+    """
+
+    def __init__(self, config):
+        """Build the layers from config.n_embd, config.n_head and config.ctx_size."""
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        self.n_head = config.n_head
+        self.ctx_size = config.ctx_size
+        self.head_size = config.n_embd // config.n_head
+
+        # learned time-weighting: time_w fills the diagonals of a (ctx_size, ctx_size) matrix per head,
+        # time_alpha / time_beta rescale its columns / rows, and time_gamma rescales the block output
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+
+        self.time_shift = nn.ZeroPad2d((0,0,1,0))
+        self.query = nn.Linear(config.n_embd, config.n_embd)
+        self.key = nn.Linear(config.n_embd, config.n_embd)
+        self.value = nn.Linear(config.n_embd, config.n_embd)
+
+        # rotary encoding is applied to the first half of each head's channels only
+        self.rotary_ndims = int(self.head_size * 0.5)
+        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
+
+        self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
+
+        self.output = nn.Linear(config.n_embd, config.n_embd)
+
+    def forward(self, x):
+        """Apply the block to x of shape (B, T, C) with T <= ctx_size; returns a (B, T, C) tensor.
+
+        Raises ValueError when T exceeds ctx_size.
+        """
+        B, T, C = x.size()
+        if T > self.ctx_size:
+            raise ValueError(f"sequence length {T} exceeds ctx_size {self.ctx_size}")
+        TT = self.ctx_size
+        w = F.pad(self.time_w, (0, TT))
+        w = torch.tile(w, [TT])
+        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
+        w = w[:, :, TT-1:] # w is now a lower-triangular Toeplitz matrix per head: w[h, i, j] = time_w[h, TT-1-(i-j)] for j <= i, else 0
+        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
+
+        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-mixing
+        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+
+        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
+        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
+        cos, sin = self.rotary_emb(q, seq_len=T)
+        q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
+        q = torch.cat((q, query_pass), dim=-1)
+        k = torch.cat((k, key_pass), dim=-1)
+
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
+        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
+        att = F.softmax(att, dim = -1) # softmax
+        att = att * w # time-weighting
+        att = self.head_mix(att) # talking heads
+
+        x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
+        x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+
+        x = self.output(x) * self.time_gamma[:T, :]
+        return x
+
 ########################################################################################################
 # The GPT Model with our blocks
 ########################################################################################################
@@ -205,9 +287,14 @@ class Block(nn.Module):
         if config.model_type == 'RWKV':
             self.attn = RWKV_TimeMix(config)
             self.mlp = RWKV_ChannelMix(config)
-        else:
+        elif config.model_type == 'RotaryMHA':
             self.attn = RotaryMHA(config)
             self.mlp = GeGLU(config)
+        elif config.model_type == 'MHA-Plus':
+            self.attn = RotaryMHA_Plus(config)
+            self.mlp = RWKV_ChannelMix(config)
+        else:
+            raise ValueError(f"unknown model_type: {config.model_type!r}")
 
     def forward(self, x):
         x = x + self.attn(self.ln1(x))
diff --git a/train.py b/train.py
index 82e5802..59fe14c 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,10 @@ set_seed(42)
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
 
-model_type = 'RWKV' # 'RWKV' or 'RotaryMHA'
+# RWKV is our proposed model - fastest when the ctx window is long - good performance
+# RotaryMHA is usual Multi-head Attention + Rotary Encoding + GeGLU FFN
+# MHA-Plus is a bit slow (lots of tricks), with excellent performance
+model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
 
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
 model_level = 'character' # 'character' or 'word'