diff --git a/src/model.py b/src/model.py
index 7cebae9..41808a3 100644
--- a/src/model.py
+++ b/src/model.py
@@ -54,10 +54,10 @@ class RWKV_TimeMix(nn.Module):
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)

-        k = k.view(B, T, self.n_head, self.head_size)
-        v = v.view(B, T, self.n_head, self.head_size)
+        kv = (k * v).view(B, T, self.n_head, self.head_size)
+
+        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, C)
-        wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)

         rwkv = torch.sigmoid(r) * wkv / sum_k

         return self.output(rwkv) * self.time_gamma[:T, :]
@@ -83,6 +83,7 @@ class RWKV_ChannelMix(nn.Module):
         r = self.receptance(x)

         wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu
+
         rwkv = torch.sigmoid(r) * wkv
         return rwkv
@@ -120,7 +121,7 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

 class MHA_rotary(nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, config, layer_id, time_shift = False):
         super().__init__()
         self.layer_id = layer_id
         assert config.n_embd % config.n_head == 0
@@ -128,6 +129,9 @@ class MHA_rotary(nn.Module):
         self.ctx_len = config.ctx_len
         self.head_size = config.n_embd // config.n_head

+        if time_shift:
+            self.time_shift = nn.ZeroPad2d((0,0,1,0))
+
         self.query = nn.Linear(config.n_embd, config.n_embd)
         self.key = nn.Linear(config.n_embd, config.n_embd)
         self.value = nn.Linear(config.n_embd, config.n_embd)
@@ -142,6 +146,9 @@ class MHA_rotary(nn.Module):
     def forward(self, x):
         B, T, C = x.size()

+        if hasattr(self, 'time_shift'):
+            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+
         q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)
         v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
@@ -160,19 +167,27 @@ class MHA_rotary(nn.Module):
         x = att @ v                                           # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
         x = x.transpose(1, 2).contiguous().view(B, T, C)      # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

-        x = self.output(x)        # output projection
+        x = self.output(x)
         return x

 class GeGLU(torch.nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, config, layer_id, time_shift = False):
         super().__init__()
         self.layer_id = layer_id
+
+        if time_shift:
+            self.time_shift = nn.ZeroPad2d((0,0,1,0))
+
         hidden_sz = 3 * config.n_embd
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)

     def forward(self, x):
+        B, T, C = x.size()
+        if hasattr(self, 'time_shift'):
+            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+
         k = self.key(x)
         v = self.value(x)
         y = self.weight(F.gelu(k) * v)
@@ -205,7 +220,7 @@ class MHA_pro(nn.Module):
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

-        self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False)   # talking heads
+        self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads

         self.output = nn.Linear(config.n_embd, config.n_embd)
@@ -218,7 +233,7 @@ class MHA_pro(nn.Module):
         w = w[:, :, TT-1:]                                    # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

-        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)  # time-mixing
+        x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)  # time-shift mixing
         q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)
         v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
@@ -300,9 +315,15 @@ class Block(nn.Module):
         if config.model_type == 'RWKV':
             self.attn = RWKV_TimeMix(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
+
         elif config.model_type == 'MHA_rotary':
             self.attn = MHA_rotary(config, layer_id)
             self.mlp = GeGLU(config, layer_id)
+
+        elif config.model_type == 'MHA_shift':
+            self.attn = MHA_rotary(config, layer_id, time_shift=True)
+            self.mlp = GeGLU(config, layer_id, time_shift=True)
+
         elif config.model_type == 'MHA_pro':
             self.attn = MHA_pro(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
diff --git a/src/trainer.py b/src/trainer.py
index fe54bbd..1fd5123 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -40,8 +40,10 @@ class Trainer:

         if 'wandb' in sys.modules:
             cfg = model.config
+            for k in config.__dict__:
+                setattr(cfg, k, config.__dict__[k]) # combine cfg
             run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
-            wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=config, save_code=False)
+            wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=cfg, save_code=False)

         # take over whatever gpus are on the system
         self.device = 'cpu'
diff --git a/train.py b/train.py
index a5c01e8..2e42008 100644
--- a/train.py
+++ b/train.py
@@ -15,14 +15,15 @@ set_seed(42)
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)

-# RWKV       - our new model - fastest when ctx_len is long - VRAM friendly - good performance
-# MHA_rotary - usual Multi-head Attention+Rotary+GeGLU - not as good
-# MHA_pro    - slow (lots of tricks) - VRAM hungry - good performance
-model_type = 'RWKV' # 'RWKV' or 'MHA_rotary' or 'MHA_pro'
+# RWKV       : our new model - fastest when ctx_len is long - VRAM friendly - good performance
+# MHA_rotary : usual Multi-head Attention+Rotary+GeGLU - not as good
+# MHA_shift  : with time-shift - good performance
+# MHA_pro    : slow (lots of tricks) - VRAM hungry - very good performance
+model_type = 'RWKV'

 # datafile = u"V:\\NLP\\text8"
 # datafile = u"V:\\NLP\\enwik8"
-datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
+datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
 datafile_encoding = 'utf-8'
 # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
 # datafile_encoding = 'utf-16'
@@ -60,8 +61,8 @@ class Dataset(Dataset):
             print('splitting token...')
             data = data.lower().split(' ')
             unique = sorted(list(set(data)))
-            for u in unique:
-                print(u, end=' ')
+            # for u in unique:
+            #     print(u, end=' ')
             data_size, vocab_size = len(data), len(unique)
             print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
             self.stoi = { ch:i for i,ch in enumerate(unique) }
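A note on the refactored RWKV_TimeMix hunk: the einsum computes a per-head weighted sum over past positions, wkv[b,t,h,c] = sum over u of w[h,t,u] * (k*v)[b,u,h,c]. A minimal standalone check of that contraction (not part of the patch; the toy sizes B, T, H, S are illustrative only):

import torch

B, T, H, S = 2, 4, 3, 5                  # toy sizes: batch, time, heads, head size
w  = torch.randn(H, T, T)                # per-head time-weight matrix, as in RWKV_TimeMix
kv = torch.randn(B, T, H, S)             # k * v reshaped to (B, T, n_head, head_size)

wkv = torch.einsum('htu,buhc->bthc', w, kv)

# Reference loop: wkv[b,t,h,c] = sum_u w[h,t,u] * kv[b,u,h,c]
ref = torch.zeros(B, T, H, S)
for u in range(T):
    ref += w[:, :, u].transpose(0, 1)[None, :, :, None] * kv[:, u, None, :, :]
assert torch.allclose(wkv, ref, atol=1e-5)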
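And a note on the time_shift branch added to MHA_rotary and GeGLU: half of the channels at each position are replaced by the same channels from the previous token, so every token sees a mix of the current and preceding time step. A minimal standalone sketch of the trick, assuming only PyTorch (shapes are illustrative, not from the patch):

import torch
import torch.nn as nn

B, T, C = 2, 5, 8
x = torch.randn(B, T, C)

# nn.ZeroPad2d((left, right, top, bottom)) pads the last two dims (T, C):
# one zero row on top of the T axis, then [:, :-1] drops the last row,
# so position t now holds x[t-1], with zeros at t = 0 (a causal shift by one).
time_shift = nn.ZeroPad2d((0, 0, 1, 0))
shifted = time_shift(x)[:, :-1, :]

# First C//2 channels come from the previous token, the rest stay current,
# matching the torch.cat line in the patch.
mixed = torch.cat([shifted[:, :, :C//2], x[:, :, C//2:]], dim=-1)
assert mixed.shape == (B, T, C)

This is the same one-token shift that MHA_pro already applies unconditionally; the patch makes it opt-in via time_shift=True, selected by the new 'MHA_shift' model type in Block.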