From 0a0eae447d9536bad99209381a5aca47fb356fe1 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Wed, 16 Feb 2022 18:02:40 +0800
Subject: [PATCH] +headQK (compatible with 2022-02-15 AI-Writer)

---
 src/model.py | 30 ++++++++++++++++---------
 train.py     | 63 +++++-----------------------------------------------
 2 files changed, 26 insertions(+), 67 deletions(-)

diff --git a/src/model.py b/src/model.py
index c97e82f..71bc098 100644
--- a/src/model.py
+++ b/src/model.py
@@ -77,7 +77,6 @@ class RWKV_TimeMix(nn.Module):
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
         self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,-1))
 
@@ -85,8 +84,8 @@ class RWKV_TimeMix(nn.Module):
         self.value = nn.Linear(config.n_embd, config.n_attn)
         self.receptance = nn.Linear(config.n_embd, config.n_attn)
 
-        if config.rwkv_tiny_attn > 0:
-            self.tiny_att = RWKV_TinyAttn(config)
+        # if config.rwkv_tiny_attn > 0:
+        #     self.tiny_att = RWKV_TinyAttn(config)
 
         self.output = nn.Linear(config.n_attn, config.n_embd)
 
@@ -102,12 +101,10 @@ class RWKV_TimeMix(nn.Module):
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
-        self.mask = self.mask[:T, :T]
-        w = w.masked_fill(self.mask == 0, 0)
 
         x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
-        if hasattr(self, 'tiny_att'):
-            tiny_att = self.tiny_att(x, self.mask)
+        # if hasattr(self, 'tiny_att'):
+        #     tiny_att = self.tiny_att(x, self.mask)
 
         k = self.key(x)
         v = self.value(x)
@@ -124,8 +121,8 @@ class RWKV_TimeMix(nn.Module):
         rwkv = torch.sigmoid(r) * wkv / sum_k
 
         rwkv = self.output(rwkv)
-        if hasattr(self, 'tiny_att'):
-            rwkv += tiny_att
+        # if hasattr(self, 'tiny_att'):
+        #     rwkv += tiny_att
 
         return rwkv * self.time_gamma[:T, :]
 
@@ -437,6 +434,12 @@ class GPT(nn.Module):
         self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
+        self.head_q = nn.Linear(config.n_embd, 256)
+        self.head_q.scale_init = 0.01
+        self.head_k = nn.Linear(config.n_embd, 256)
+        self.head_k.scale_init = 0.01
+        self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
+
         self.ctx_len = config.ctx_len
 
         if self.config.model_type == 'RWKV':
@@ -497,8 +500,15 @@ class GPT(nn.Module):
 
         x = self.blocks(x)
         x = self.ln_f(x)
+
+        q = self.head_q(x)[:,:T,:]
+        k = self.head_k(x)[:,:T,:]
+        c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
+        c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
+        c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
+
         x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
-        x = self.head(x)
+        x = self.head(x) + c
 
         loss = None
         if targets is not None:
diff --git a/train.py b/train.py
index e573ded..ab370e1 100644
--- a/train.py
+++ b/train.py
@@ -38,7 +38,7 @@ datafile_type = 0 # use 0 for char-level english. use 1 for chinese. only affect
 epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
 epoch_save_path = 'trained-'
 
-batch_size = 48 # if you see "CUDA out of memory", reduce this.
+batch_size = 32 # if you see "CUDA out of memory", reduce this.
                 # if you have good GPU, increase this.
                 # use GPU-Z to find the highest value for your VRAM.
 
@@ -48,18 +48,18 @@ n_epoch = 100 # the 'epoch' here is actua
 model_level = 'character' # 'character' (recommended) or 'word'
 
 ctx_len = 256 # context length, try 512 or 1024 if you have good GPU
-n_layer = 5 # try 12 for 100M, 24 for 300M
+n_layer = 6 # try 12 for 100M, 24 for 300M
 n_head = 8 # try 12 for 100M, 16 for 300M
 n_embd = n_head * 64
 n_attn = n_embd
 n_ffn = n_embd
 
-lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
-lr_final = 1e-5 # 1e-5 = 0.00001
+lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
+lr_final = 4e-5
 
-betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
-eps = 1e-8
+betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99)
+eps = 4e-9
 
 weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 
@@ -140,54 +140,3 @@ trainer = Trainer(model, train_dataset, None, tconf)
 trainer.train()
 
 torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
-
-########################################################################################################
-# Run model to generate text
-########################################################################################################
-
-from src.utils import sample_logits
-
-NUM_OF_RUNS = 5
-LENGTH_OF_EACH = 300
-
-for run in range(NUM_OF_RUNS):
-    context = "\n"
-
-    if model_level == 'word':
-        x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
-    else:
-        x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
-
-    real_len = len(x)
-    if real_len < ctx_len:
-        x = np.pad(x, (0, ctx_len - real_len))
-    print_begin = 0
-
-    for i in range(LENGTH_OF_EACH):
-
-        if i == 0:
-            print(('-' * 80) + '\n' + context, end = '')
-            print_begin = real_len
-
-        with torch.no_grad():
-            xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...].to("cuda:0")
-            out, _ = model(xxx)
-        pos = -1 if real_len >= ctx_len else real_len - 1
-
-        char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
-
-        if real_len < ctx_len:
-            x[real_len] = char
-        else:
-            x = np.append(x, char)
-        real_len += 1
-
-        if i % 10 == 9 or i == LENGTH_OF_EACH-1:
-            if model_level == 'word':
-                completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
-                completion = completion.replace('\n ', '\n')
-            else:
-                completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
-            print(completion, end = '')
-            print_begin = real_len
-    print()
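
Note on the headQK change in src/model.py above: the patch gives the language-model head a small "copy" pathway. Two 256-dimensional projections (head_q, head_k) score every position against the earlier positions under a causal copy_mask, the scores are scattered onto the vocabulary ids of those earlier tokens via a one-hot matrix, and the result is added to the ordinary head(x) logits, which lets the model re-emit tokens it has already seen in the context. The sketch below restates that mechanism as a standalone PyTorch module under assumed shapes; the CopyHead class name, the head_dim and ctx_len arguments, and the usage under __main__ are illustrative and not part of the patch.

    # Standalone sketch of the headQK / "copy head" mechanism (assumed shapes; names illustrative).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CopyHead(nn.Module):
        def __init__(self, n_embd, vocab_size, ctx_len, head_dim=256):
            super().__init__()
            self.head_dim = head_dim
            self.vocab_size = vocab_size
            self.head_q = nn.Linear(n_embd, head_dim)  # query projection of the final hidden states
            self.head_k = nn.Linear(n_embd, head_dim)  # key projection of the final hidden states
            # causal mask: position t may only score positions <= t
            self.register_buffer("copy_mask", torch.tril(torch.ones(ctx_len, ctx_len)))

        def forward(self, x, idx):
            # x: (B, T, n_embd) final hidden states, idx: (B, T) input token ids
            B, T, _ = x.shape
            q = self.head_q(x)                                  # (B, T, head_dim)
            k = self.head_k(x)                                  # (B, T, head_dim)
            c = (q @ k.transpose(-2, -1)) / self.head_dim       # (B, T, T) pairwise scores
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)   # zero out scores for future positions
            # scatter each score onto the vocabulary id of the token it points at, giving a
            # per-position bonus over the vocabulary that favours copying earlier tokens
            c = c @ F.one_hot(idx, num_classes=self.vocab_size).float()  # (B, T, vocab_size)
            return c

    if __name__ == "__main__":
        B, T, n_embd, vocab_size, ctx_len = 2, 8, 512, 100, 16
        copy_head = CopyHead(n_embd, vocab_size, ctx_len)
        x = torch.randn(B, T, n_embd)
        idx = torch.randint(0, vocab_size, (B, T))
        bonus = copy_head(x, idx)
        print(bonus.shape)  # torch.Size([2, 8, 100])
        # in the patch, the final logits are: self.head(x) + bonus

Two details carried over from the patch: the scores are scaled by 1/256 (the projection width) rather than 1/sqrt(d), and masked positions are filled with 0 rather than -inf, because the result is an additive logit bonus rather than a softmax attention distribution.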