diff --git a/src/model.py b/src/model.py
index 2d9f118..2f9e79c 100644
--- a/src/model.py
+++ b/src/model.py
@@ -13,10 +13,7 @@ logger = logging.getLogger(__name__)
 # RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
 
-rwkv_emb_scale = 0.4 # try 0.4 for char-level english. try 1.0 for chinese.
-rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0.
-
-def RWKV_Init(module, config): # fancy initialization of every lin & emb layer in the module
+def RWKV_Init(module, config): # fancy initialization of all lin & emb layers in the module
     for m in module.modules():
         if not isinstance(m, (nn.Linear, nn.Embedding)):
             continue
@@ -27,7 +24,7 @@ def RWKV_Init(module, config): # fancy initialization of every lin & emb layer i
                 break
 
         shape = m.weight.data.shape
-        gain = 1.0 # positive: gain for orthogonal, negative: std for normal
+        gain = 1.0  # positive: gain for orthogonal, negative: std for normal
         scale = 1.0 # extra scale for gain
 
         if isinstance(m, nn.Linear):
@@ -36,12 +33,12 @@ def RWKV_Init(module, config): # fancy initialization of every lin & emb layer i
             if shape[0] > shape[1]:
                 gain = math.sqrt(shape[0] / shape[1])
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
-                scale = rwkv_emb_scale
+                scale = config.rwkv_emb_scale
 
         if isinstance(m, nn.Embedding):
             gain = math.sqrt(max(shape[0], shape[1]))
             if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
-                scale = rwkv_emb_scale
+                scale = config.rwkv_emb_scale
 
         if hasattr(m, 'scale_init'):
             scale = m.scale_init
@@ -63,7 +60,7 @@ class RWKV_TimeMix(nn.Module):
         self.n_head = config.n_head
         self.head_size = config.n_attn // config.n_head
 
-        with torch.no_grad(): # build initial time_w curves for better convergence
+        with torch.no_grad(): # initial time_w curves for better convergence
             ww = torch.zeros(config.n_head, config.ctx_len)
             curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
             curve = curve * 2 + 0.7
@@ -91,11 +88,14 @@ class RWKV_TimeMix(nn.Module):
         self.value = nn.Linear(config.n_embd, config.n_attn)
         self.receptance = nn.Linear(config.n_embd, config.n_attn)
 
+        if config.rwkv_tiny_attn > 0:
+            self.tiny_att = RWKV_TinyAttn(config)
+
         self.output = nn.Linear(config.n_attn, config.n_embd)
 
         self.key.scale_init = 0
         self.receptance.scale_init = 0
-        self.output.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
+        self.output.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
 
     def forward(self, x):
         B, T, C = x.size()
@@ -105,14 +105,18 @@ class RWKV_TimeMix(nn.Module):
         w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
         w = w[:, :, TT-1:] # w is now a circulant matrix
         w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
-        w = w.masked_fill(self.mask[:T, :T] == 0, 0)
+        self.mask = self.mask[:T, :T]
+        w = w.masked_fill(self.mask == 0, 0)
 
         x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
+        if hasattr(self, 'tiny_att'):
+            tiny_att = self.tiny_att(x, self.mask)
+
         k = self.key(x)
         v = self.value(x)
         r = self.receptance(x)
 
-        k = torch.clamp(k, max=30) # clamp extreme values
+        k = torch.clamp(k, max=30) # clamp extreme values. e^30 = 10^13
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
@@ -122,7 +126,11 @@ class RWKV_TimeMix(nn.Module):
 
         rwkv = torch.sigmoid(r) * wkv / sum_k
 
-        return self.output(rwkv) * self.time_gamma[:T, :]
+        rwkv = self.output(rwkv)
+        if hasattr(self, 'tiny_att'):
+            rwkv += tiny_att
+
+        return rwkv * self.time_gamma[:T, :]
 
 class RWKV_ChannelMix(nn.Module):
     def __init__(self, config, layer_id):
@@ -130,14 +138,14 @@ class RWKV_ChannelMix(nn.Module):
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of R
+        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)
        self.receptance = nn.Linear(config.n_embd, config.n_embd)
 
         self.receptance.scale_init = 0
-        self.weight.scale_init = 1 / pow(1+layer_id, rwkv_layer_decay) # decay weight in higher layers
+        self.weight.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
 
     def forward(self, x):
         B, T, C = x.size()
@@ -147,12 +155,42 @@ class RWKV_ChannelMix(nn.Module):
 
         v = self.value(x)
         r = self.receptance(x)
 
-        wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu
+        wkv = self.weight(F.mish(k) * v) # I find mish is a bit better than gelu
         rwkv = torch.sigmoid(r) * wkv
         return rwkv
 
+class RWKV_TinyAttn(nn.Module): # extra tiny attention
+    def __init__(self, config):
+        super().__init__()
+        self.d_attn = config.rwkv_tiny_attn
+        self.n_head = config.rwkv_tiny_head
+        self.head_size = self.d_attn // self.n_head
+
+        self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
+        self.out = nn.Linear(self.d_attn, config.n_embd)
+
+    def forward(self, x, mask):
+        B, T, C = x.size()
+        qkv = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim = -1)
+
+        if self.n_head > 1:
+            q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)
+            k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)
+            v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)
+
+        qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))     # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
+        qk = qk.masked_fill(mask == 0, float('-inf'))
+        qk = F.softmax(qk, dim = -1)
+        qkv = qk @ v                                                           # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
+
+        if self.n_head > 1:
+            qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1)              # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+
+        return self.out(qkv)
+
 ########################################################################################################
 # MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
 ########################################################################################################
@@ -182,7 +220,7 @@ def rotate_half(x):
 
 @torch.jit.script
 def apply_rotary_pos_emb(q, k, cos, sin):
-    cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:]
+    cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:]
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 class MHA_rotary(nn.Module):
@@ -223,7 +261,7 @@ class MHA_rotary(nn.Module):
         cos, sin = self.rotary_emb(q, seq_len=T)
         q, k = apply_rotary_pos_emb(q, k, cos, sin)                         # rotary encoding
         q = torch.cat((q, query_pass), dim=-1)
-        k = torch.cat((k, key_pass), dim=-1)
+        k = torch.cat((k, key_pass), dim=-1)
 
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))     # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
         att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))         # causal mask
diff --git a/train.py b/train.py
index f4b6a87..53a639f 100644
--- a/train.py
+++ b/train.py
@@ -2,8 +2,7 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
-import os, sys, time, math, random, json, datetime
-import logging
+import os, sys, time, math, random, json, datetime, logging
 import numpy as np
 import torch
 from torch.utils.data import Dataset
@@ -28,9 +27,11 @@ datafile_encoding = 'utf-8'
 # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
 # datafile_encoding = 'utf-16'
 
-model_level = 'character' # 'character' or 'word'
+datafile_type = 0 # use 0 for char-level english. use 1 for chinese. only affects some RWKV hyperparameters
 
-ctx_len = 256 # context length
+model_level = 'character' # 'character' (recommended) or 'word'
+
+ctx_len = 256 # context length
 n_layer = 5
 n_head = 8
 n_embd = n_head * 64
@@ -40,14 +41,21 @@ n_ffn = n_embd
 batch_size = 64
 n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
 
-lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
+lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
 lr_final = 2e-4
 
 betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
 eps = 1e-8
-weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
+weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
+
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 
+######## special hyperparameters for RWKV model ########
+rwkv_layer_decay = 1.0 # reduce initial weight in higher layers. try 0.5 ~ 1.0
+rwkv_emb_scale = 0.4 if datafile_type == 0 else 0.8 # use 0.4 for char-level english, 0.8 for chinese
+rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
+rwkv_tiny_head = 1 # 1 is good enough
+
 ########################################################################################################
 # Load data
 ########################################################################################################
@@ -94,6 +102,7 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
 ########################################################################################################
 
 model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
+                      rwkv_emb_scale=rwkv_emb_scale, rwkv_layer_decay=rwkv_layer_decay, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
                       n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
 
 print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
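
Note on the mask passed to RWKV_TinyAttn.forward: the patch hands it self.mask sliced to (T, T), and masked_fill then broadcasts that 2D mask over the 4D (B, nh, T, T) score tensor. A minimal standalone sketch of that convention, assuming (as in the rest of the file) that the buffer is a lower-triangular causal mask; the toy sizes below are illustrative only.

    # not part of the patch: sketch of the assumed causal-mask broadcast
    import torch
    import torch.nn.functional as F

    T = 4
    mask = torch.tril(torch.ones(T, T))      # (T, T) causal mask: 1 = keep, 0 = future position
    scores = torch.randn(2, 1, T, T)         # (B, nh, T, T) raw attention scores
    scores = scores.masked_fill(mask == 0, float('-inf'))
    probs = F.softmax(scores, dim=-1)        # each row sums to 1; masked positions get zero weight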
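Note on the config plumbing: the model code now reads config.rwkv_emb_scale, config.rwkv_layer_decay, config.rwkv_tiny_attn and config.rwkv_tiny_head, and train.py passes them as keyword arguments to GPTConfig. This assumes GPTConfig stores extra keyword arguments as attributes (minGPT-style); the stub and values below are a hypothetical sketch mirroring train.py's defaults, not the repo's actual GPTConfig.

    # not part of the patch: assumed minGPT-style config that keeps every kwarg as an attribute
    class GPTConfig:
        def __init__(self, vocab_size, ctx_len, **kwargs):
            self.vocab_size = vocab_size
            self.ctx_len = ctx_len
            for k, v in kwargs.items():
                setattr(self, k, v)    # makes config.rwkv_emb_scale etc. visible to RWKV_Init / RWKV_TimeMix

    # hypothetical usage mirroring train.py (vocab_size is arbitrary here);
    # tiny attention stays disabled because ctx_len <= 600, so rwkv_tiny_attn = 0
    config = GPTConfig(vocab_size=10000, ctx_len=256, model_type='RWKV',
                       rwkv_emb_scale=0.4, rwkv_layer_decay=1.0,
                       rwkv_tiny_attn=0, rwkv_tiny_head=1,
                       n_layer=5, n_head=8, n_embd=512, n_attn=512, n_ffn=512)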