diff --git a/src/model.py b/src/model.py
index d5f36a2..20131a0 100644
--- a/src/model.py
+++ b/src/model.py
@@ -80,7 +80,7 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
-        wkv = self.weight(F.gelu(k) * v)
+        wkv = self.weight(F.mish(k) * v) # mish is a bit better than gelu
         y = torch.sigmoid(r) * wkv
 
         return y
@@ -312,6 +312,7 @@ class Block(nn.Module):
 class GPT(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
 
         self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
 
@@ -334,6 +335,11 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
+            # if self.config.model_type == 'RWKV' and isinstance(module, nn.Linear):
+            #     gain_layer = min(3, module.weight.shape[0] / module.weight.shape[1])
+            #     depth_factor = min(1, 1 / math.sqrt(self.config.n_layer / 5))
+            #     nn.init.orthogonal_(module.weight, gain = gain_layer * depth_factor) # will nan for large models
+            # else:
             module.weight.data.normal_(mean=0.0, std=0.01)
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
diff --git a/train.py b/train.py
index 190a18e..bb4498c 100644
--- a/train.py
+++ b/train.py
@@ -32,9 +32,14 @@ nLayers = 5
 nHead = 8
 nEmb = 512
 
-lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher LR
+lr_initial = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
 lr_final = 2e-4
+
+lr_initial /= math.sqrt(nLayers / 5) # lower lr for deep models; higher lr for shallow models
+lr_final /= math.sqrt(nLayers / 5)
+
 betas = (0.9, 0.99)
+weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when you have enough data
 
 nepoch = 50 # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
@@ -87,7 +92,7 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_ty
               n_layer=nLayers, n_head=nHead, n_embd=nEmb))
 print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size)
 
-tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz,
+tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz, weight_decay=weight_decay,
                       learning_rate=lr_initial, lr_decay=True, lr_final=lr_final, betas=betas,
                       warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0)
 trainer = Trainer(model, train_dataset, None, tconf)