remove layernorm -> better RWKV

main
BlinkDL 4 years ago
parent 55405c57d0
commit e9fbd9bf70

@@ -237,7 +237,7 @@ class RotaryMHA_Plus(nn.Module):
 # The GPT Model with our blocks
 ########################################################################################################
 
-class LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss
+class LabelSmoothingCrossEntropy(nn.Module): # can avoid nan loss
     def __init__(self, smoothing=0.0):
         super().__init__()
         self.confidence = 1.0 - smoothing
@@ -251,6 +251,29 @@ class LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss
         true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
         return torch.mean(torch.sum(-true_dist * pred, dim=-1))
 
+class RMSNorm(nn.Module):
+    def __init__(self, d):
+        super().__init__()
+        self.dd = d ** (-1. / 2)
+        self.weight = nn.Parameter(torch.ones(d))
+
+    def forward(self, x):
+        norm_x = x.norm(2, dim=-1, keepdim=True)
+        x_normed = x / (norm_x * self.dd + 1e-12)
+        return self.weight * x_normed
+
+class SimpleRMSNorm(nn.Module):
+    def __init__(self, d):
+        super().__init__()
+        self.dd = d ** (-1. / 2)
+
+    def forward(self, x):
+        norm_x = x.norm(2, dim=-1, keepdim=True)
+        x_normed = x / (norm_x * self.dd + 1e-12)
+        return x_normed
+
+########################################################################################################
+
 class GPTConfig:
     def __init__(self, vocab_size, ctx_size, **kwargs):
         self.vocab_size = vocab_size
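For reference, a minimal standalone sketch (not part of the commit) of what the added SimpleRMSNorm computes: the input is divided by its root-mean-square over the channel dimension, RMS(x) = ||x|| / sqrt(d), so every position comes out with unit RMS and, unlike RMSNorm, there is no learnable per-channel weight.

import torch

d = 512
x = torch.randn(4, d)
rms = x.norm(2, dim=-1, keepdim=True) * d ** (-0.5)   # same quantity as norm_x * self.dd above
x_normed = x / (rms + 1e-12)                          # what SimpleRMSNorm returns
print(x_normed.pow(2).mean(dim=-1).sqrt())            # ~1.0 for every row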
@@ -266,6 +289,8 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
+            self.ln1 = nn.Identity() # remove first LayerNorm -> faster convergence for deep models
+            self.ln2 = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
             self.attn = RWKV_TimeMix(config)
             self.mlp = RWKV_ChannelMix(config)
         elif config.model_type == 'RotaryMHA':
@@ -278,6 +303,7 @@ class Block(nn.Module):
     def forward(self, x):
         x = x + self.attn(self.ln1(x))
         x = x + self.mlp(self.ln2(x))
         return x
+
 
 class GPT(nn.Module):
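Taken together with the Block.__init__ hunk above, the unchanged forward means an RWKV block now runs its time-mix directly on the raw residual stream (ln1 is nn.Identity) and its channel-mix on an RMS-normalized input. A sketch of the effective data flow, assuming the attn/mlp modules shown above (not code from the commit):

# effective per-block computation when config.model_type == 'RWKV'
def rwkv_block_forward(block, x):
    x = x + block.attn(x)             # ln1 is nn.Identity, so time-mix sees un-normalized x
    x = x + block.mlp(block.ln2(x))   # ln2 is SimpleRMSNorm(config.n_embd)
    return x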
@@ -288,7 +314,11 @@ class GPT(nn.Module):
         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
 
-        self.ln_f = nn.LayerNorm(config.n_embd)
+        if config.model_type == 'RWKV':
+            self.ln_f = SimpleRMSNorm(config.n_embd) # SimpleRMSNorm is good enough for RWKV -> less parameters
+        else:
+            self.ln_f = nn.LayerNorm(config.n_embd)
+
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.ctx_size = config.ctx_size
@@ -311,7 +341,7 @@ class GPT(nn.Module):
         no_decay = set()
         whitelist_weight_modules = (nn.Linear, )
-        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
+        blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)
         for mn, m in self.named_modules():
             for pn, p in m.named_parameters():
                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
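RMSNorm is added to the blacklist presumably so that its weight parameter is excluded from weight decay, following the usual minGPT-style grouping. A self-contained illustration of that decay/no-decay split (a generic sketch of the pattern, not the repo's configure_optimizers):

import torch
import torch.nn as nn

# toy model: only Linear weights should receive weight decay;
# norm weights, embedding weights and biases should not
model = nn.Sequential(nn.Embedding(100, 32), nn.LayerNorm(32), nn.Linear(32, 100))

decay, no_decay = [], []
for module in model.modules():
    for name, p in module.named_parameters(recurse=False):
        if isinstance(module, nn.Linear) and name == 'weight':
            decay.append(p)
        else:
            no_decay.append(p)

optimizer = torch.optim.AdamW(
    [{'params': decay, 'weight_decay': 0.01},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=6e-4, betas=(0.9, 0.99))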

@@ -19,6 +19,10 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s
 model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
 
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
+datafile_encoding = 'utf-8'
+# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
+# datafile_encoding = 'utf-16'
+
 model_level = 'character' # 'character' or 'word'
 ctx_size = 256 if model_level == 'character' else 128
@@ -26,6 +30,10 @@ nLayers = 5
 nHead = 8
 nEmb = 512
 
+lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher LR
+lr_final = 2e-4
+betas = (0.9, 0.99)
+
 nepoch = 50 # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
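The learning-rate knobs are now named constants: training starts at lr_initial (6e-4 for RWKV, 4e-4 otherwise) and decays toward lr_final = 2e-4 over final_tokens. The actual schedule lives inside the Trainer and is not part of this diff; purely as an illustration of token-based decay between the two values, a hypothetical linear-decay helper could look like:

# hypothetical illustration only: linear token-based decay from lr_initial to lr_final
def lr_at(tokens_seen, final_tokens, lr_initial=6e-4, lr_final=2e-4):
    progress = min(tokens_seen / final_tokens, 1.0)
    return lr_initial + (lr_final - lr_initial) * progress

print(lr_at(0, 1_000_000))          # 6e-4 at the start
print(lr_at(500_000, 1_000_000))    # 4e-4 halfway
print(lr_at(1_000_000, 1_000_000))  # 2e-4 at the end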
@@ -65,7 +73,7 @@ class Dataset(Dataset):
         y = torch.tensor(dix[1:], dtype=torch.long)
         return x, y
 
-train_dataset = Dataset(open(datafile, "r", encoding="utf-8").read(), model_level, ctx_size)
+train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
 
 ########################################################################################################
@@ -74,8 +82,8 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_ty
 print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
 
 tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz,
-                      learning_rate=6e-4 if model_type == 'RWKV' else 4e-4, betas=(0.9, 0.99), # RWKV can use higher LR
-                      lr_decay=True, lr_final=2e-4, warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
+                      learning_rate=lr_initial, lr_decay=True, lr_final=lr_final, betas=betas,
+                      warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
 trainer = Trainer(model, train_dataset, None, tconf)
 trainer.train()
