From a19be54bf52a1f4f6d2cb302db2a2627d0e61e1f Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Thu, 28 Oct 2021 15:24:04 +0800
Subject: [PATCH] no message

---
 train.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index 1e0b2bd..4782cf2 100644
--- a/train.py
+++ b/train.py
@@ -47,9 +47,10 @@ n_epoch = 100 # the 'epoch' here is actua
 
 model_level = 'character' # 'character' (recommended) or 'word'
 
-ctx_len = 256 # context length
-n_layer = 5
-n_head = 8
+ctx_len = 256 # context length, try 512 or 1024 if you have good GPU
+n_layer = 5 # try 12 for 100M, 24 for 300M
+n_head = 8 # try 12 for 100M, 16 for 300M
+
 n_embd = n_head * 64
 n_attn = n_embd
 n_ffn = n_embd
@@ -65,9 +66,9 @@ epoch_length_fixed = 10000 # make an 'epoch' very short
 ######## special hyperparameters for RWKV model ########
 
 rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
-rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
+rwkv_tiny_attn = 0#64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
 rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
-# n_side_proj = 512 # extra 'side projection', quite useful for BPE models
+# n_side_proj = 512 # extra 'side projection', quite useful for BPE models
 
 ########################################################################################################
 # Load data
@@ -127,6 +128,9 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_typ
                 rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
                 n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
 
+# load a trained model
+# model.load_state_dict(torch.load('trained-xxx.pth').state_dict())
+
 print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
 tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
                       learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
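
Note on the commented-out checkpoint line this patch adds: calling .state_dict() on the object
returned by torch.load(...) only works if the checkpoint was written with torch.save(model, path),
i.e. the whole module was pickled rather than just its weights. Below is a minimal sketch of
resuming from such a checkpoint, assuming "model" is the GPT instance constructed just above in
train.py; the 'trained-xxx.pth' name is the placeholder from the patch, and the isinstance branch
is an extra assumption to also cover checkpoints saved as a plain state dict:

    import torch

    checkpoint_path = 'trained-xxx.pth'  # placeholder path, as in the patch
    loaded = torch.load(checkpoint_path, map_location='cpu')  # map_location is a defensive choice

    if isinstance(loaded, dict):
        # checkpoint was saved with torch.save(model.state_dict(), path)
        model.load_state_dict(loaded)
    else:
        # checkpoint was saved with torch.save(model, path) -- matches the patch's commented line
        model.load_state_dict(loaded.state_dict())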