+headQK (compatible with 2022-02-15 AI-Writer)

main
BlinkDL committed 4 years ago (parent b48aa1d430, commit 0a0eae447d)

@@ -77,7 +77,6 @@ class RWKV_TimeMix(nn.Module):
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
         self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
         self.time_shift = nn.ZeroPad2d((0,0,1,-1))
@@ -85,8 +84,8 @@ class RWKV_TimeMix(nn.Module):
         self.value = nn.Linear(config.n_embd, config.n_attn)
         self.receptance = nn.Linear(config.n_embd, config.n_attn)
-        if config.rwkv_tiny_attn > 0:
-            self.tiny_att = RWKV_TinyAttn(config)
+        # if config.rwkv_tiny_attn > 0:
+        #     self.tiny_att = RWKV_TinyAttn(config)
         self.output = nn.Linear(config.n_attn, config.n_embd)
@@ -102,12 +101,10 @@ class RWKV_TimeMix(nn.Module):
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
-        self.mask = self.mask[:T, :T]
-        w = w.masked_fill(self.mask == 0, 0)
         x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
-        if hasattr(self, 'tiny_att'):
-            tiny_att = self.tiny_att(x, self.mask)
+        # if hasattr(self, 'tiny_att'):
+        #     tiny_att = self.tiny_att(x, self.mask)
         k = self.key(x)
         v = self.value(x)
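Note on the deletion above: the removed lines applied a causal lower-triangular mask to the per-head time-weight matrix w, zeroing w[h, i, j] for j > i so position i never attends to the future. A minimal standalone sketch of that operation (toy shapes, names local to the sketch, not the repo's code):

    import torch

    n_head, T = 2, 5
    w = torch.ones(n_head, T, T)          # stand-in per-head time-weight matrix
    mask = torch.tril(torch.ones(T, T))   # 1 on and below the diagonal
    w = w.masked_fill(mask == 0, 0)       # zero out weights on future positions
    print(w[0])                           # row i keeps entries 0..i, rest are 0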
@@ -124,8 +121,8 @@ class RWKV_TimeMix(nn.Module):
         rwkv = torch.sigmoid(r) * wkv / sum_k
         rwkv = self.output(rwkv)
-        if hasattr(self, 'tiny_att'):
-            rwkv += tiny_att
+        # if hasattr(self, 'tiny_att'):
+        #     rwkv += tiny_att
         return rwkv * self.time_gamma[:T, :]
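For context on the gating in this hunk: the output is a sigmoid "receptance" gate times a normalized time-weighted mixture, rescaled by a per-position time_gamma. A toy sketch, assuming wkv is a weighted sum of key-modulated values and sum_k its matching normalizer (stand-in tensors, not the repo's forward pass):

    import torch

    B, T, C = 1, 4, 8
    r = torch.randn(B, T, C)              # receptance logits
    wkv = torch.randn(B, T, C)            # stand-in: time-weighted sum of k*v
    sum_k = torch.rand(B, T, C) + 1e-8    # stand-in: normalizer, kept positive
    time_gamma = torch.ones(T, 1)         # per-position output scale

    rwkv = torch.sigmoid(r) * wkv / sum_k # gated, normalized mixture
    out = rwkv * time_gamma[:T, :]        # rescale each position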
@@ -437,6 +434,12 @@ class GPT(nn.Module):
         self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.head_q = nn.Linear(config.n_embd, 256)
+        self.head_q.scale_init = 0.01
+        self.head_k = nn.Linear(config.n_embd, 256)
+        self.head_k.scale_init = 0.01
+        self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
         self.ctx_len = config.ctx_len
         if self.config.model_type == 'RWKV':
@@ -497,8 +500,15 @@ class GPT(nn.Module):
         x = self.blocks(x)
         x = self.ln_f(x)
+        q = self.head_q(x)[:,:T,:]
+        k = self.head_k(x)[:,:T,:]
+        c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
+        c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
+        c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
         x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
-        x = self.head(x)
+        x = self.head(x) + c
         loss = None
         if targets is not None:
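This hunk, together with the module definitions added earlier, is the headQK mechanism named in the commit title: two small linear heads produce q and k, their scaled dot product is causally masked, and multiplying by a one-hot encoding of the input ids turns the (T, T) match scores into (T, vocab) "copy logits" added to the regular LM head. A self-contained sketch with toy sizes (the dimensions below are hypothetical, not the repo's config):

    import torch
    import torch.nn.functional as F

    B, T, C, vocab_size, head_size = 1, 6, 16, 100, 256

    head_q = torch.nn.Linear(C, head_size)
    head_k = torch.nn.Linear(C, head_size)
    copy_mask = torch.tril(torch.ones(T, T))           # causal mask

    x = torch.randn(B, T, C)                           # final hidden states
    idx = torch.randint(0, vocab_size, (B, T))         # input token ids

    q = head_q(x)                                      # (B, T, 256)
    k = head_k(x)                                      # (B, T, 256)
    c = (q @ k.transpose(-2, -1)) * (1.0 / head_size)  # (B, T, T) match scores
    c = c.masked_fill(copy_mask[:T, :T] == 0, 0)       # no copying from the future
    c = c @ F.one_hot(idx, num_classes=vocab_size).float()  # (B, T, vocab) copy logits
    # these are added to the LM head's output: logits = head(x) + c

When position i's query matches position j's key, the model gets a direct boost toward re-emitting the token that appeared at j, which helps it copy names and rare strings already present in the context.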

@@ -38,7 +38,7 @@ datafile_type = 0 # use 0 for char-level english. use 1 for chinese. only affect
 epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
 epoch_save_path = 'trained-'
-batch_size = 48 # if you see "CUDA out of memory", reduce this.
+batch_size = 32 # if you see "CUDA out of memory", reduce this.
 # if you have good GPU, increase this.
 # use GPU-Z to find the highest value for your VRAM.
@@ -48,18 +48,18 @@ n_epoch = 100 # the 'epoch' here is actua
 model_level = 'character' # 'character' (recommended) or 'word'
 ctx_len = 256 # context length, try 512 or 1024 if you have good GPU
-n_layer = 5 # try 12 for 100M, 24 for 300M
+n_layer = 6 # try 12 for 100M, 24 for 300M
 n_head = 8 # try 12 for 100M, 16 for 300M
 n_embd = n_head * 64
 n_attn = n_embd
 n_ffn = n_embd
-lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
+lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
-lr_final = 1e-5 # 1e-5 = 0.00001
+lr_final = 4e-5
-betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
+betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99)
-eps = 1e-8
+eps = 4e-9
 weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
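A minimal sketch of how these hyperparameters plausibly reach the optimizer. The repo's Trainer builds the optimizer and schedule internally; the Adam call below uses the values from this hunk, while the geometric decay from lr_init to lr_final is an assumption for illustration:

    import torch

    model_type = 'RWKV'
    lr_init = 6e-4 if model_type == 'RWKV' else 4e-4
    lr_final = 4e-5
    betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99)
    eps = 4e-9
    n_epoch = 100

    model = torch.nn.Linear(8, 8)  # stand-in module
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_init, betas=betas, eps=eps)

    # assumed schedule: multiply lr by a constant factor each epoch so it
    # lands on lr_final at the last epoch
    gamma = (lr_final / lr_init) ** (1.0 / max(n_epoch - 1, 1))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)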
@@ -140,54 +140,3 @@ trainer = Trainer(model, train_dataset, None, tconf)
 trainer.train()
 torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
-########################################################################################################
-# Run model to generate text
-########################################################################################################
-from src.utils import sample_logits
-NUM_OF_RUNS = 5
-LENGTH_OF_EACH = 300
-for run in range(NUM_OF_RUNS):
-    context = "\n"
-    if model_level == 'word':
-        x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
-    else:
-        x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
-    real_len = len(x)
-    if real_len < ctx_len:
-        x = np.pad(x, (0, ctx_len - real_len))
-    print_begin = 0
-    for i in range(LENGTH_OF_EACH):
-        if i == 0:
-            print(('-' * 80) + '\n' + context, end = '')
-            print_begin = real_len
-        with torch.no_grad():
-            xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...].to("cuda:0")
-            out, _ = model(xxx)
-        pos = -1 if real_len >= ctx_len else real_len - 1
-        char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
-        if real_len < ctx_len:
-            x[real_len] = char
-        else:
-            x = np.append(x, char)
-        real_len += 1
-        if i % 10 == 9 or i == LENGTH_OF_EACH-1:
-            if model_level == 'word':
-                completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
-                completion = completion.replace('\n ', '\n')
-            else:
-                completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
-            print(completion, end = '')
-            print_begin = real_len
-    print()
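The deleted script depends on sample_logits from src.utils, whose body is not part of this diff. A plausible reconstruction of a min-p style filter consistent with the min_p_pow/min_p_ratio arguments above (an assumption, not the repo's verified implementation):

    import torch
    import torch.nn.functional as F

    def sample_logits_sketch(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02):
        # out: (1, T, vocab) logits; pos: position to sample from
        probs = F.softmax(out[0][pos, :], dim=-1)
        # assumed min-p rule: drop tokens whose probability falls below
        # (max probability) ** min_p_pow * min_p_ratio
        limit = torch.max(probs) ** min_p_pow * min_p_ratio
        probs[probs < limit] = 0
        if temperature != 1.0:
            probs = probs ** (1.0 / temperature)   # temperature applied to probs
        probs = probs / probs.sum()                # renormalize
        return torch.multinomial(probs, num_samples=1).item()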
