+ RWKV tiny-attn and now it's great for ctx 1024 or 2048

main
BlinkDL 4 years ago
parent a9f39c112c
commit 7f391c5758

@@ -13,10 +13,7 @@ logger = logging.getLogger(__name__)
 # RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
-rwkv_emb_scale = 0.4 # try 0.4 for char-level english. try 1.0 for chinese.
-rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0.
-def RWKV_Init(module, config): # fancy initialization of every lin & emb layer in the module
+def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
     for m in module.modules():
         if not isinstance(m, (nn.Linear, nn.Embedding)):
             continue
@@ -27,7 +24,7 @@ def RWKV_Init(module, config): # fancy initialization of every lin & emb layer i
                 break
         shape = m.weight.data.shape
         gain = 1.0 # positive: gain for orthogonal, negative: std for normal
         scale = 1.0 # extra scale for gain
         if isinstance(m, nn.Linear):
@@ -36,12 +33,12 @@ def RWKV_Init(module, config): # fancy initialization of every lin & emb layer i
             if shape[0] > shape[1]:
                 gain = math.sqrt(shape[0] / shape[1])
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
-                scale = rwkv_emb_scale
+                scale = config.rwkv_emb_scale
         if isinstance(m, nn.Embedding):
             gain = math.sqrt(max(shape[0], shape[1]))
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
-                scale = rwkv_emb_scale
+                scale = config.rwkv_emb_scale
         if hasattr(m, 'scale_init'):
             scale = m.scale_init
@@ -63,7 +60,7 @@ class RWKV_TimeMix(nn.Module):
         self.n_head = config.n_head
         self.head_size = config.n_attn // config.n_head
-        with torch.no_grad(): # build initial time_w curves for better convergence
+        with torch.no_grad(): # initial time_w curves for better convergence
             ww = torch.zeros(config.n_head, config.ctx_len)
             curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
             curve = curve * 2 + 0.7
@@ -91,11 +88,14 @@ class RWKV_TimeMix(nn.Module):
         self.value = nn.Linear(config.n_embd, config.n_attn)
         self.receptance = nn.Linear(config.n_embd, config.n_attn)
+        if config.rwkv_tiny_attn > 0:
+            self.tiny_att = RWKV_TinyAttn(config)
         self.output = nn.Linear(config.n_attn, config.n_embd)
         self.key.scale_init = 0
         self.receptance.scale_init = 0
-        self.output.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
+        self.output.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
     def forward(self, x):
         B, T, C = x.size()
@@ -105,14 +105,18 @@ class RWKV_TimeMix(nn.Module):
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
-        w = w.masked_fill(self.mask[:T, :T] == 0, 0)
+        self.mask = self.mask[:T, :T]
+        w = w.masked_fill(self.mask == 0, 0)
         x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+        if hasattr(self, 'tiny_att'):
+            tiny_att = self.tiny_att(x, self.mask)
         k = self.key(x)
         v = self.value(x)
         r = self.receptance(x)
-        k = torch.clamp(k, max=30) # clamp extreme values
+        k = torch.clamp(k, max=30) # clamp extreme values. e^30 = 10^13
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
@@ -122,7 +126,11 @@ class RWKV_TimeMix(nn.Module):
         rwkv = torch.sigmoid(r) * wkv / sum_k
-        return self.output(rwkv) * self.time_gamma[:T, :]
+        rwkv = self.output(rwkv)
+        if hasattr(self, 'tiny_att'):
+            rwkv += tiny_att
+        return rwkv * self.time_gamma[:T, :]

 class RWKV_ChannelMix(nn.Module):
     def __init__(self, config, layer_id):
@@ -130,14 +138,14 @@ class RWKV_ChannelMix(nn.Module):
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of R
+        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
         self.receptance.scale_init = 0
-        self.weight.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
+        self.weight.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
     def forward(self, x):
         B, T, C = x.size()
@@ -147,12 +155,42 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
-        wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu
+        wkv = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu
         rwkv = torch.sigmoid(r) * wkv
         return rwkv

+class RWKV_TinyAttn(nn.Module): # extra tiny attention
+    def __init__(self, config):
+        super().__init__()
+        self.d_attn = config.rwkv_tiny_attn
+        self.n_head = config.rwkv_tiny_head
+        self.head_size = self.d_attn // self.n_head
+        self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
+        self.out = nn.Linear(self.d_attn, config.n_embd)
+    def forward(self, x, mask):
+        B, T, C = x.size()
+        qkv = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim = -1)
+        if self.n_head > 1:
+            q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+            k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+            v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+        qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size)) # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
+        qk = qk.masked_fill(mask == 0, float('-inf'))
+        qk = F.softmax(qk, dim = -1)
+        qkv = qk @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
+        if self.n_head > 1:
+            qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+        return self.out(qkv)
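(Editor's note, not part of the commit: a minimal standalone sketch of how the new RWKV_TinyAttn can be exercised. Inside RWKV_TimeMix it is called as self.tiny_att(x, self.mask) and its output is added onto the time-mix output; the mask construction and the config sizes below are assumed example values, since self.mask is registered outside the hunks shown here.)

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(n_embd=512, rwkv_tiny_attn=64, rwkv_tiny_head=1)  # hypothetical sizes
tiny = RWKV_TinyAttn(cfg)

B, T = 2, 1024
x = torch.randn(B, T, cfg.n_embd)
mask = torch.tril(torch.ones(T, T))  # causal mask: position t only attends to positions <= t
y = tiny(x, mask)                    # -> (B, T, n_embd), added onto the RWKV_TimeMix output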
 ########################################################################################################
 # MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
 ########################################################################################################
@@ -182,7 +220,7 @@ def rotate_half(x):
 @torch.jit.script
 def apply_rotary_pos_emb(q, k, cos, sin):
-    cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:]
+    cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:]
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

 class MHA_rotary(nn.Module):

@@ -2,8 +2,7 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
-import os, sys, time, math, random, json, datetime
-import logging
+import os, sys, time, math, random, json, datetime, logging
 import numpy as np
 import torch
 from torch.utils.data import Dataset
@@ -28,9 +27,11 @@ datafile_encoding = 'utf-8'
 # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
 # datafile_encoding = 'utf-16'
-model_level = 'character' # 'character' or 'word'
-ctx_len = 256 # context length
+datafile_type = 0 # use 0 for char-level english. use 1 for chinese. only affects some RWKV hyperparameters
+model_level = 'character' # 'character' (recommended) or 'word'
+ctx_len = 256 # context length
 n_layer = 5
 n_head = 8
 n_embd = n_head * 64
@@ -40,14 +41,21 @@ n_ffn = n_embd
 batch_size = 64
 n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
-lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
+lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
 lr_final = 2e-4
 betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
 eps = 1e-8
-weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
+weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress

+######## special hyperparameters for RWKV model ########
+rwkv_layer_decay = 1.0 # reduce initial weight in higher layers. try 0.5 ~ 1.0
+rwkv_emb_scale = 0.4 if datafile_type == 0 else 0.8 # use 0.4 for char-level english, 0.8 for chinese
+rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
+rwkv_tiny_head = 1 # 1 is good enough
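(Editor's note, not part of the commit: a quick check of what these switches resolve to under an assumed long-context, char-level run matching the commit title, i.e. datafile_type = 0 and ctx_len = 1024.)

datafile_type, ctx_len = 0, 1024                                      # assumed settings, matching "ctx 1024 or 2048"
rwkv_emb_scale = 0.4 if datafile_type == 0 else 0.8                   # -> 0.4
rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0  # -> 64: tiny attention enabled
rwkv_tiny_head = 1                                                    # -> 1 head of size 64
# at the default ctx_len = 256 the same expression gives rwkv_tiny_attn = 0, i.e. tiny attention stays disabled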
 ########################################################################################################
 # Load data
 ########################################################################################################
@@ -94,6 +102,7 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
 ########################################################################################################
 model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
+                      rwkv_emb_scale=rwkv_emb_scale, rwkv_layer_decay=rwkv_layer_decay, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
                       n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
 print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
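(Editor's note: GPTConfig itself is not touched by this diff. In the minGPT-style setup used here, extra keyword arguments are simply stored as attributes, which is what lets model.py read config.rwkv_emb_scale, config.rwkv_layer_decay, config.rwkv_tiny_attn and config.rwkv_tiny_head as seen in the hunks above. A hedged sketch of that pattern, which may differ in detail from the real class:)

class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)  # e.g. rwkv_emb_scale, rwkv_layer_decay, rwkv_tiny_attn, rwkv_tiny_head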
