misc improvements

branch: main
BlinkDL, 4 years ago
parent 6266f481da
commit d699a69169

@@ -80,7 +80,7 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
-        wkv = self.weight(F.gelu(k) * v)
+        wkv = self.weight(F.mish(k) * v) # mish is a bit better than gelu
         y = torch.sigmoid(r) * wkv
         return y
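
The gelu -> mish swap above is the whole functional change in ChannelMix. As a quick sanity check (my sketch, not part of the commit), F.mish in PyTorch >= 1.9 computes x * tanh(softplus(x)), so the two expressions below agree:

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-3, 3, 7)
    print(F.gelu(x))                        # previous activation
    print(F.mish(x))                        # new activation
    print(x * torch.tanh(F.softplus(x)))    # Mish by hand: matches F.mish(x)
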
@@ -312,6 +312,7 @@ class Block(nn.Module):
 class GPT(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
@@ -334,6 +335,11 @@ class GPT(nn.Module):
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
+            # if self.config.model_type == 'RWKV' and isinstance(module, nn.Linear):
+            #     gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
+            #     depth_factor = min(1, 1 / math.sqrt(self.config.n_layer / 5))
+            #     nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor) # will nan for large models
+            # else:
             module.weight.data.normal_(mean=0.0, std=0.01)
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
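
The commented-out branch above keeps the orthogonal-init experiment for reference (it stores self.config precisely so model_type and n_layer are reachable here). A small worked example of what it would compute, with a hypothetical 2048x512 nn.Linear weight in a 10-layer model (my numbers, not from the commit):

    import math

    out_dim, in_dim = 2048, 512                   # hypothetical module.weight.shape
    gain_layer = min(3, out_dim / in_dim)         # fan-out/fan-in ratio, capped at 3 -> 3
    depth_factor = min(1, 1 / math.sqrt(10 / 5))  # 10 layers vs the 5-layer baseline -> ~0.707
    print(gain_layer * depth_factor)              # ~2.12, the gain nn.init.orthogonal_ would receive
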

@@ -32,9 +32,14 @@ nLayers = 5
 nHead = 8
 nEmb = 512
-lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher LR
+lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
 lr_final = 2e-4
+lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for shallow models
+lr_final /= math.sqrt(nLayers / 5)
 betas = (0.9, 0.99)
+weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when you have enough data
 nepoch = 50 # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
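
The new depth scaling divides both learning rates by sqrt(nLayers / 5), so the 5-layer default is untouched and deeper models train with a smaller lr. A quick check of the arithmetic (illustration only):

    import math

    for nLayers in (5, 20):
        scale = math.sqrt(nLayers / 5)   # 1.0 at the 5-layer baseline, 2.0 for 20 layers
        print(nLayers, 6e-4 / scale, 2e-4 / scale)
    # 5  -> 6e-4, 2e-4 (unchanged)
    # 20 -> 3e-4, 1e-4 (both halved)
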
@@ -87,7 +92,7 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type,
                       n_layer=nLayers, n_head=nHead, n_embd=nEmb))
 print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
-tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz,
+tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz, weight_decay=weight_decay,
                       learning_rate=lr_initial, lr_decay=True, lr_final=lr_final, betas=betas,
                       warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
 trainer = Trainer(model, train_dataset, None, tconf)
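
The new weight_decay knob is now threaded through TrainerConfig as well. How the trainer consumes it is outside this diff; a minimal sketch, assuming a minGPT-style trainer that simply hands the value to torch.optim.AdamW (with the RWKV setting of 0 this behaves like plain Adam):

    import torch

    def make_optimizer(model, learning_rate, betas, weight_decay):
        # weight_decay = 0 disables decoupled weight decay entirely
        return torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                 betas=betas, weight_decay=weight_decay)
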
