misc improvements

Branch: main
Author: BlinkDL, 4 years ago
Parent: ef29f4b9e8
Commit: c675b47705

@@ -51,7 +51,7 @@ class RWKV_TimeMix(nn.Module):
v = self.value(x)
r = self.receptance(x)
k = torch.clamp(k, max=30) # clamp crazy values
k = torch.clamp(k, max=30) # clamp extreme values
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
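A minimal standalone sketch (not part of the commit) of why k is capped before the exponential: exp() overflows float32 once its argument passes roughly 88, and the cumulative sums built from it would then turn into inf/nan.

import torch

k = torch.tensor([10.0, 50.0, 100.0])
print(torch.exp(k))                       # exp(100) overflows float32 -> inf
k = torch.clamp(k, max=30)                # same cap as in RWKV_TimeMix
print(torch.exp(k))                       # finite, at most exp(30) ~ 1.07e13
print(torch.cumsum(torch.exp(k), dim=0))  # the running sums stay finite as well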
@@ -300,6 +300,8 @@ class Block(nn.Module):
self.ln2 = nn.LayerNorm(config.n_embd)
if config.model_type == 'RWKV':
# self.ln1 = FixedNorm(config.n_embd)
# self.ln2 = FixedNorm(config.n_embd)
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
@@ -332,6 +334,7 @@ class GPT(nn.Module):
self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.ctx_len = config.ctx_len
@@ -345,30 +348,31 @@ class GPT(nn.Module):
ww[k] *= math.sqrt(self.config.vocab_size)
else:
ww[k] *= math.sqrt(self.config.n_embd)
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for Chinese
elif 'head.weight' in k:
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for Chinese
elif 'blocks.' in k:
block_id = int(k.split('.')[1])
if 'receptance.weight' in k:
ww[k] *= 0 # 0 works the best
ww[k] *= 0 # init with zero matrix
elif 'attn.key.weight' in k:
ww[k] *= 0 # 0 works the best
ww[k] *= 0 # init with zero matrix
elif 'attn.output.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'mlp.weight.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'attn.time_w' in k:
if self.config.n_head > 1: # different time_w for different head
curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
curve = curve * 2 + 0.7
for h in range(self.config.n_head):
curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
curve = torch.pow(curve, 24) # concentrated effect
curve = curve - torch.mean(curve) + 1 # normalize mean to 1
if self.config.n_head > 1:
mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
else:
mix_strength = 0.5
ww[k][h] = (1 - mix_strength) + curve * mix_strength
# special tweaks because of time_shift
ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
ww[k][h][self.config.ctx_len - 1] = 1
# print(k, h, mix_strength, ww[k][h])
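The per-head time_w initialization above is easier to read as a standalone sketch (not part of the commit; ctx_len and n_head are example values, and the last-position time_shift tweaks are omitted):

import torch

ctx_len, n_head = 16, 4
curve = torch.arange(ctx_len, dtype=torch.float32) / (ctx_len - 1)
curve = torch.pow(curve, 24)             # concentrate weight on the most recent positions
curve = curve - torch.mean(curve) + 1    # keep the mean at 1
time_w = torch.ones(n_head, ctx_len)
for h in range(n_head):
    # heads sweep from a strong recency bias (1.0) down to a mild reverse bias (-0.2)
    mix_strength = 1 - 1.2 * h / (n_head - 1) if n_head > 1 else 0.5
    time_w[h] = (1 - mix_strength) + curve * mix_strength
print(time_w)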
@@ -421,7 +425,7 @@ class GPT(nn.Module):
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
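For reference, 1e-8 is also torch's default AdamW eps, so the change above only makes it configurable from TrainerConfig; since eps sits in the denominator of the update (square root of the second-moment estimate plus eps), raising it damps steps for parameters whose gradient variance is near zero. A minimal construction sketch (not part of the commit) using the default values from trainer.py below:

import torch

w = torch.nn.Parameter(torch.zeros(10))
optim_groups = [{"params": [w], "weight_decay": 0.01}]
optimizer = torch.optim.AdamW(optim_groups, lr=4e-4, betas=(0.9, 0.99), eps=1e-8)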
@@ -433,6 +437,7 @@ class GPT(nn.Module):
x = self.blocks(x)
x = self.ln_f(x)
x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
x = self.head(x)
loss = None
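time_out has shape (1, ctx_len, 1), so the multiply above rescales the whole hidden vector at each position by one learned factor, and because the head is linear with bias=False the logits are scaled by that same factor; a factor below 1 flattens the softmax, i.e. lowers confidence, which is the stated intent for early tokens. A quick illustration (not part of the commit):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])
print(F.softmax(logits, dim=-1))        # ~[0.665, 0.245, 0.090]
print(F.softmax(0.5 * logits, dim=-1))  # flatter: ~[0.506, 0.307, 0.186]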

@@ -15,7 +15,8 @@ class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.95)
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
weight_decay = 0.01
lr_decay = False # linear warmup followed by cosine decay
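These are only class-level defaults. Assuming the usual minGPT-style __init__ that copies keyword arguments onto the instance (it is outside this hunk), train.py can still override betas/eps per run:

class TrainerConfig:
    betas = (0.9, 0.99)
    eps = 1e-8
    def __init__(self, **kwargs):   # assumed pattern; the real __init__ is not shown in this diff
        for k, v in kwargs.items():
            setattr(self, k, v)

tconf = TrainerConfig(betas=(0.9, 0.999), eps=1e-8)
print(tconf.betas, tconf.eps)       # (0.9, 0.999) 1e-08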

@@ -38,10 +38,11 @@ n_embd = n_head * 64
batch_size = 64
n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
lr_final = 2e-4
betas = (0.9, 0.99)
betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
eps = 1e-8
weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
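The higher lr_init for RWKV decays toward lr_final during training. A standalone sketch (not part of the commit; the trainer's exact implementation is not shown in this diff) of the schedule the TrainerConfig comment describes, "linear warmup followed by cosine decay", with warmup_tokens=0 as used below:

import math

lr_init, lr_final = 8e-4, 2e-4

def lr_at(tokens_seen, warmup_tokens, final_tokens):
    if tokens_seen < warmup_tokens:
        return lr_init * tokens_seen / max(1, warmup_tokens)      # linear warmup
    progress = (tokens_seen - warmup_tokens) / max(1, final_tokens - warmup_tokens)
    coeff = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))  # cosine: 1 -> 0
    return lr_final + coeff * (lr_init - lr_final)

print(lr_at(0, 0, 10_000), lr_at(5_000, 0, 10_000), lr_at(10_000, 0, 10_000))
# -> lr_init at the start, ~5e-4 halfway through, lr_final at the end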
@@ -91,9 +92,9 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_head=n_head, n_embd=n_embd))
print('model', model_type, 'total epoch', n_epoch, 'batch_size', batch_size, 'n_layer', n_layer, 'n_head', n_head, 'n_embd', n_embd, 'len', ctx_len)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'ctx', ctx_len)
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
trainer = Trainer(model, train_dataset, None, tconf)
