@@ -47,9 +47,10 @@ n_epoch = 100 # the 'epoch' here is actually very short and of fixed length
 
 model_level = 'character' # 'character' (recommended) or 'word'
 
-ctx_len = 256 # context length
-n_layer = 5
-n_head = 8
+ctx_len = 256 # context length, try 512 or 1024 if you have good GPU
+n_layer = 5 # try 12 for 100M, 24 for 300M
+n_head = 8 # try 12 for 100M, 16 for 300M
+
 n_embd = n_head * 64
 n_attn = n_embd
 n_ffn = n_embd
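The "100M" / "300M" in the new comments are rough parameter counts. As a sanity check, here is a back-of-the-envelope sketch (not part of train.py) using the standard GPT-style estimate of about 12 * n_layer * n_embd^2 weights in the transformer blocks, ignoring embeddings and biases:

    # hypothetical helper, not in the repo: rough transformer parameter count,
    # assuming params ~= 12 * n_layer * n_embd**2 (blocks only, no embeddings)
    def approx_params(n_layer, n_head, head_dim=64):
        n_embd = n_head * head_dim  # mirrors n_embd = n_head * 64 above
        return 12 * n_layer * n_embd ** 2

    print(approx_params(5, 8))    # default config          -> ~16M
    print(approx_params(12, 12))  # "12 layers, 12 heads"   -> ~85M, the ~100M class
    print(approx_params(24, 16))  # "24 layers, 16 heads"   -> ~302M, the ~300M class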
@@ -65,7 +66,7 @@ epoch_length_fixed = 10000 # make an 'epoch' very short
 
 ######## special hyperparameters for RWKV model ########
 rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
-rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
+rwkv_tiny_attn = 0 #64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
 rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
 # n_side_proj = 512 # extra 'side projection', quite useful for BPE models
 
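This change turns the tiny attention off by default instead of enabling it automatically for long-context character-level data. For reference, here is how the old (now commented-out) expression evaluates; datafile_type is defined earlier in train.py, where 0 appears to mark the char-level English dataset:

    # sketch of the old default (assumption: datafile_type = 0 means char-level data)
    datafile_type = 0
    for ctx_len in (256, 1024):
        rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0
        print(ctx_len, '->', rwkv_tiny_attn)  # 256 -> 0, 1024 -> 64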
@@ -127,6 +128,9 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                     rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head,
                     n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
+
+# load a trained model
+# model.load_state_dict(torch.load('trained-xxx.pth').state_dict())
 
 print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
 tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
                       learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
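The new commented-out loader calls .state_dict() on the object returned by torch.load, which implies checkpoints are saved as whole modules (torch.save(model, path)) rather than as bare state dicts. A minimal round-trip sketch under that assumption, reusing the model built above and a hypothetical filename:

    import torch

    torch.save(model, 'trained-demo.pth')  # save the whole module (hypothetical filename)

    loaded = torch.load('trained-demo.pth')     # returns the saved GPT module
    model.load_state_dict(loaded.state_dict())  # copy its weights back into 'model'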