@@ -13,10 +13,7 @@ logger = logging.getLogger(__name__)
 # RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
 
-rwkv_emb_scale = 0.4 # try 0.4 for char-level english. try 1.0 for chinese.
-rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0.
 
-def RWKV_Init(module, config): # fancy initialization of every lin & emb layer in the module
+def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
     for m in module.modules():
         if not isinstance(m, (nn.Linear, nn.Embedding)):
             continue
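
Note: with this hunk, `rwkv_emb_scale` and `rwkv_layer_decay` stop being module-level globals and become fields on `config` (see the `config.rwkv_emb_scale` / `config.rwkv_layer_decay` reads below). A minimal sketch of what callers now have to supply — only the `rwkv_*` names come from this diff (0.4 and 1.0 were the old global defaults); the class name and every other field here are assumptions:

```python
class RWKVConfig:  # hypothetical container; the real repo builds its config elsewhere
    def __init__(self, vocab_size, ctx_len, n_embd, n_head, n_attn, n_ffn,
                 rwkv_emb_scale=0.4,    # old global default: 0.4 for char-level english, 1.0 for chinese
                 rwkv_layer_decay=1.0,  # old global default: reduce initial weight in higher layers
                 rwkv_tiny_attn=0,      # 0 disables the extra tiny attention added below
                 rwkv_tiny_head=1):
        self.vocab_size, self.ctx_len = vocab_size, ctx_len
        self.n_embd, self.n_head = n_embd, n_head
        self.n_attn, self.n_ffn = n_attn, n_ffn
        self.rwkv_emb_scale = rwkv_emb_scale
        self.rwkv_layer_decay = rwkv_layer_decay
        self.rwkv_tiny_attn = rwkv_tiny_attn
        self.rwkv_tiny_head = rwkv_tiny_head
```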
@@ -27,7 +24,7 @@ def RWKV_Init(module, config): # fancy initialization of every lin & emb layer i
                 break
 
         shape = m.weight.data.shape
         gain = 1.0 # positive: gain for orthogonal, negative: std for normal
         scale = 1.0 # extra scale for gain
 
         if isinstance(m, nn.Linear):
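
Note: the `gain` comment encodes a sign convention that the rest of `RWKV_Init` (outside this hunk) presumably resolves into an init call. A sketch of that convention under that assumption — `apply_fancy_init` is a made-up name:

```python
import torch.nn as nn

def apply_fancy_init(weight, gain, scale):
    g = gain * scale
    if g > 0:
        nn.init.orthogonal_(weight, gain=g)        # positive: orthogonal init, g is the gain
    else:
        nn.init.normal_(weight, mean=0.0, std=-g)  # negative: normal init, -g is the std
```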
@@ -36,12 +33,12 @@ def RWKV_Init(module, config): # fancy initialization of every lin & emb layer i
             if shape[0] > shape[1]:
                 gain = math.sqrt(shape[0] / shape[1])
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
-                scale = rwkv_emb_scale
+                scale = config.rwkv_emb_scale
 
         if isinstance(m, nn.Embedding):
             gain = math.sqrt(max(shape[0], shape[1]))
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
-                scale = rwkv_emb_scale
+                scale = config.rwkv_emb_scale
 
         if hasattr(m, 'scale_init'):
             scale = m.scale_init
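
Note: the `hasattr(m, 'scale_init')` branch is a per-layer override hook: a module tags one of its own layers, and `RWKV_Init` picks the tag up when it walks the tree (the time-mix and channel-mix hunks below use exactly this). Usage sketch:

```python
import torch.nn as nn

proj = nn.Linear(512, 512)
proj.scale_init = 0   # tag read later by RWKV_Init: zero-init this projection
# inside RWKV_Init the tag wins over the computed scale:
#     if hasattr(m, 'scale_init'):
#         scale = m.scale_init
```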
@@ -63,7 +60,7 @@ class RWKV_TimeMix(nn.Module):
         self.n_head = config.n_head
         self.head_size = config.n_attn // config.n_head
 
-        with torch.no_grad(): # build initial time_w curves for better convergence
+        with torch.no_grad(): # initial time_w curves for better convergence
             ww = torch.zeros(config.n_head, config.ctx_len)
             curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
             curve = curve * 2 + 0.7
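
Note: the initial `time_w` curve gives each past position an exponentially decaying weight (the newest position gets 0.9^0 = 1, older ones shrink by a factor 0.9 per step), then remaps it into roughly [0.7, 2.7]. Standalone check:

```python
import torch

ctx_len = 8
curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range(ctx_len)])
print(curve[0].item(), curve[-1].item())   # ~0.478 (oldest) ... 1.0 (newest)
curve = curve * 2 + 0.7                    # -> ~1.66 ... 2.7; tends to 0.7 for long ctx_len
```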
@@ -91,11 +88,14 @@ class RWKV_TimeMix(nn.Module):
         self.value = nn.Linear(config.n_embd, config.n_attn)
         self.receptance = nn.Linear(config.n_embd, config.n_attn)
 
+        if config.rwkv_tiny_attn > 0:
+            self.tiny_att = RWKV_TinyAttn(config)
+
         self.output = nn.Linear(config.n_attn, config.n_embd)
 
         self.key.scale_init = 0
         self.receptance.scale_init = 0
-        self.output.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
+        self.output.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
 
     def forward(self, x):
         B, T, C = x.size()
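
Note: `scale_init = 1 / pow(1 + layer_id, config.rwkv_layer_decay)` makes deeper layers start with smaller output weights. With the default `rwkv_layer_decay = 1.0`:

```python
for layer_id in range(4):
    print(layer_id, 1 / pow(1 + layer_id, 1.0))
# 0 1.0
# 1 0.5
# 2 0.3333333333333333
# 3 0.25
```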
@@ -105,14 +105,18 @@ class RWKV_TimeMix(nn.Module):
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
-        w = w.masked_fill(self.mask[:T, :T] == 0, 0)
+        self.mask = self.mask[:T, :T]
+        w = w.masked_fill(self.mask == 0, 0)
 
         x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+        if hasattr(self, 'tiny_att'):
+            tiny_att = self.tiny_att(x, self.mask)
 
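Note: the `time_shift` trick mixes each token with its predecessor on half of the channels: `ZeroPad2d((0,0,1,0))` pushes the sequence one step forward in time, `[:, :-1]` drops the overflow row, and the first `C//2` channels then come from position t-1 while the rest stay at t. Standalone sketch:

```python
import torch
import torch.nn as nn

B, T, C = 1, 4, 6
x = torch.arange(B * T * C, dtype=torch.float32).view(B, T, C)
time_shift = nn.ZeroPad2d((0, 0, 1, 0))                           # pad one zero row at t=0
shifted = time_shift(x)[:, :-1, :]                                # row t now holds x[t-1]
mixed = torch.cat([shifted[:, :, :C//2], x[:, :, C//2:]], dim=-1)
assert torch.equal(mixed[:, 1:, :C//2], x[:, :-1, :C//2])         # first half is time-shifted
```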
         k = self.key(x)
         v = self.value(x)
         r = self.receptance(x)
 
-        k = torch.clamp(k, max=30) # clamp extreme values
+        k = torch.clamp(k, max=30) # clamp extreme values. e^30 = 10^13
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
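
Note: the clamp bounds k before exponentiation so `exp` cannot overflow (e^30 ≈ 1.07e13, comfortably inside fp32 range even after a long `cumsum`), and the cumulative sum is the running normalizer used for `wkv / sum_k` below. Tiny check:

```python
import torch

k = torch.tensor([[1.0, 100.0, 3.0]])  # 100 would overflow: e^100 ~ 2.7e43 > fp32 max
k = torch.clamp(k, max=30)             # cap at 30: e^30 ~ 1.07e13
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)         # running denominator per time step
print(torch.isfinite(sum_k).all())     # tensor(True)
```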
@@ -122,7 +126,11 @@ class RWKV_TimeMix(nn.Module):
         rwkv = torch.sigmoid(r) * wkv / sum_k
 
-        return self.output(rwkv) * self.time_gamma[:T, :]
+        rwkv = self.output(rwkv)
+        if hasattr(self, 'tiny_att'):
+            rwkv += tiny_att
+
+        return rwkv * self.time_gamma[:T, :]
 
 class RWKV_ChannelMix(nn.Module):
     def __init__(self, config, layer_id):
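
Note: written out, the time-mix output is a receptance-gated, exp-weighted average of past values; the tiny attention (when enabled) is added after the output projection, before the per-position `time_gamma` rescale. Scalar-channel summary (heads and the reduction over them are elided, so read this as shape-free pseudocode, not the exact kernel):

```python
# rwkv[t] = sigmoid(r[t]) * ( sum_{u<=t} w[t,u] * exp(k[u]) * v[u] ) / ( sum_{u<=t} exp(k[u]) )
# out[t]  = (W_out @ rwkv[t] + tiny_att[t]) * time_gamma[t]
```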
@@ -130,14 +138,14 @@ class RWKV_ChannelMix(nn.Module):
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of R
+        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
 
         self.receptance.scale_init = 0
-        self.weight.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
+        self.weight.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
 
     def forward(self, x):
         B, T, C = x.size()
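
Note: `5 * config.n_ffn // 2` is a 2.5x hidden width, versus the common 4x FFN; per the comment, the extra receptance matrix supplies the missing capacity, so the parameter counts come out roughly comparable:

```python
n_embd = n_ffn = 512
hidden_sz = 5 * n_ffn // 2                                       # 1280 = 2.5 * 512
gated = 2 * n_embd * hidden_sz + hidden_sz * n_embd + n_embd * n_embd  # key+value+weight+receptance
plain = 2 * n_embd * (4 * n_ffn)                                 # standard 4x two-matrix FFN
print(gated, plain)                                              # 2228224 vs 2097152
```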
@@ -147,12 +155,42 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
-        wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu
+        wkv = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu
 
         rwkv = torch.sigmoid(r) * wkv
 
         return rwkv
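
Note: `F.mish` (available in PyTorch >= 1.9; older versions need the manual form) computes x * tanh(softplus(x)), a smooth non-monotonic ReLU relative that the comment prefers over GELU here:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
assert torch.allclose(F.mish(x), x * torch.tanh(F.softplus(x)))  # mish's definition
```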
+class RWKV_TinyAttn(nn.Module): # extra tiny attention
+    def __init__(self, config):
+        super().__init__()
+        self.d_attn = config.rwkv_tiny_attn
+        self.n_head = config.rwkv_tiny_head
+        self.head_size = self.d_attn // self.n_head
+
+        self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
+        self.out = nn.Linear(self.d_attn, config.n_embd)
+
+    def forward(self, x, mask):
+        B, T, C = x.size()
+        qkv = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim = -1)
+
+        if self.n_head > 1:
+            q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+            k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+            v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
+
+        qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size)) # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
+        qk = qk.masked_fill(mask == 0, float('-inf'))
+        qk = F.softmax(qk, dim = -1)
+        qkv = qk @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
+
+        if self.n_head > 1:
+            qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+
+        return self.out(qkv)
 
 ########################################################################################################
 # MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
 ########################################################################################################
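
Note: quick shape check for the new class, reusing the hypothetical `RWKVConfig` sketched earlier (only the two `rwkv_tiny_*` fields matter here):

```python
import torch

cfg = RWKVConfig(vocab_size=100, ctx_len=16, n_embd=64, n_head=4,
                 n_attn=64, n_ffn=64, rwkv_tiny_attn=16, rwkv_tiny_head=2)
att = RWKV_TinyAttn(cfg)
x = torch.randn(2, 16, 64)             # (B, T, n_embd)
mask = torch.tril(torch.ones(16, 16))  # causal: zeros above the diagonal get -inf
print(att(x, mask).shape)              # torch.Size([2, 16, 64])
```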
@@ -182,7 +220,7 @@ def rotate_half(x):
 
 @torch.jit.script
 def apply_rotary_pos_emb(q, k, cos, sin):
-    cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:]
+    cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:]
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 class MHA_rotary(nn.Module):
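
Note: the fix replaces a hard-coded axis 2 with -2, so the slice always lands on the sequence axis of a (..., T, hs) tensor no matter how many leading dims q carries. For reference, the usual form of the `rotate_half` helper named in the hunk header (its body is outside this diff, so this is an assumption):

```python
import torch

def rotate_half(x):
    # split the last dim in two and rotate the pair: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
```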
@@ -223,7 +261,7 @@ class MHA_rotary(nn.Module):
         cos, sin = self.rotary_emb(q, seq_len=T)
         q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
         q = torch.cat((q, query_pass), dim=-1)
         k = torch.cat((k, key_pass), dim=-1)
 
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
         att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
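
Note: the `query_pass` / `key_pass` concatenations indicate partial rotary: only the leading rotary channels of each head get rotated, the remainder passes through untouched. The split that would produce them, with `rotary_ndims` as an assumed knob, plausibly looks like:

```python
import torch

rotary_ndims = 16
q = torch.randn(2, 4, 8, 32)                                      # (B, nh, T, hs)
q_rot, query_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]  # rotate only the first part
# ...apply_rotary_pos_emb on q_rot only, then re-concat as in the diff:
q = torch.cat((q_rot, query_pass), dim=-1)
```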