misc improvements

main
BlinkDL 4 years ago
parent ef29f4b9e8
commit c675b47705

@@ -33,11 +33,11 @@ class RWKV_TimeMix(nn.Module):
self.key = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.output = nn.Linear(config.n_embd, config.n_embd)
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
@@ -51,7 +51,7 @@ class RWKV_TimeMix(nn.Module):
v = self.value(x)
r = self.receptance(x)
k = torch.clamp(k, max=30) # clamp crazy values
k = torch.clamp(k, max=30) # clamp extreme values
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
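Clamping k at 30 before the exponential keeps exp(k) finite; below is a minimal standalone sketch of the clamp → exp → cumsum step with illustrative shapes (batch 2, T 4, C 8 are assumptions, not values from this file):

import torch

k = torch.randn(2, 4, 8) * 20      # stand-in key activations, deliberately large
k = torch.clamp(k, max=30)         # cap extreme values so exp() cannot overflow
k = torch.exp(k)                   # strictly positive weights
sum_k = torch.cumsum(k, dim=1)     # running sum along the time dimension
print(sum_k.shape)                 # torch.Size([2, 4, 8])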
@@ -154,7 +154,7 @@ class MHA_rotary(nn.Module):
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
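apply_rotary_pos_emb itself is outside this hunk; a common GPT-NeoX-style formulation (an assumption about that helper, not taken from this commit, with cos and sin assumed broadcastable over (B, nh, T, rotary_ndims)) looks like this:

import torch

def rotate_half(x):
    # split the last dimension in half; (-x2, x1) replaces (x1, x2)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin):
    # rotate the rotary_ndims channels of q and k by a position-dependent angle
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin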
@@ -163,7 +163,7 @@ class MHA_rotary(nn.Module):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
att = F.softmax(att, dim = -1) # softmax
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
@@ -196,7 +196,7 @@ class GeGLU(torch.nn.Module):
########################################################################################################
# MHA_pro: with more tricks
########################################################################################################
class MHA_pro(nn.Module):
def __init__(self, config, layer_id):
@@ -211,7 +211,7 @@ class MHA_pro(nn.Module):
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,0))
self.query = nn.Linear(config.n_embd, config.n_embd)
@@ -239,7 +239,7 @@ class MHA_pro(nn.Module):
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
@@ -283,7 +283,7 @@ class FixedNorm(nn.Module):
x_normed = x / (norm_x * self.dd + 1e-12)
return x_normed
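FixedNorm's constructor is not visible in this hunk; a minimal sketch assuming self.dd = d ** -0.5 (an RMS-style normalization with no learned gain, consistent with the forward line above but not confirmed by this diff; FixedNormSketch is an illustrative name):

import torch
import torch.nn as nn

class FixedNormSketch(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.dd = d ** (-0.5)                       # assumed scale: 1 / sqrt(d)

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        return x / (norm_x * self.dd + 1e-12)       # unit-RMS output per feature vector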
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
@@ -300,6 +300,8 @@ class Block(nn.Module):
self.ln2 = nn.LayerNorm(config.n_embd)
if config.model_type == 'RWKV':
# self.ln1 = FixedNorm(config.n_embd)
# self.ln2 = FixedNorm(config.n_embd)
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
@@ -332,6 +334,7 @@ class GPT(nn.Module):
self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.ctx_len = config.ctx_len
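time_out has shape (1, ctx_len, 1), so a single learnable scale per position broadcasts across batch and channels when applied after ln_f; a small shape sketch (sizes chosen only for illustration):

import torch

B, T, C, ctx_len = 2, 5, 8, 16
time_out = torch.ones(1, ctx_len, 1)   # one scale factor per position
x = torch.randn(B, T, C)               # hidden states after the final LayerNorm
x = x * time_out[:, :T, :]             # (1, T, 1) broadcasts over (B, T, C)
print(x.shape)                         # torch.Size([2, 5, 8])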
@@ -345,32 +348,33 @@ class GPT(nn.Module):
ww[k] *= math.sqrt(self.config.vocab_size)
else:
ww[k] *= math.sqrt(self.config.n_embd)
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
elif 'head.weight' in k:
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
elif 'blocks.' in k:
block_id = int(k.split('.')[1])
if 'receptance.weight' in k:
ww[k] *= 0 # 0 works the best
ww[k] *= 0 # init with zero matrix
elif 'attn.key.weight' in k:
ww[k] *= 0 # 0 works the best
ww[k] *= 0 # init with zero matrix
elif 'attn.output.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'mlp.weight.weight' in k:
ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
elif 'attn.time_w' in k:
if self.config.n_head > 1: # different time_w for different head
for h in range(self.config.n_head):
curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
curve = torch.pow(curve, 24) # concentrated effect
curve = curve - torch.mean(curve) + 1 # normalize mean to 1
curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
curve = curve * 2 + 0.7
for h in range(self.config.n_head):
if self.config.n_head > 1:
mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
ww[k][h] = (1 - mix_strength) + curve * mix_strength
# special tweaks because of time_shift
ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
ww[k][h][self.config.ctx_len - 1] = 1
# print(k, h, mix_strength, ww[k][h])
else:
mix_strength = 0.5
ww[k][h] = (1 - mix_strength) + curve * mix_strength
# special tweaks because of time_shift
ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
ww[k][h][self.config.ctx_len - 1] = 1
# print(k, h, mix_strength, ww[k][h])
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
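The reworked time_w init builds an exponentially decaying curve and blends it per head; a small numeric sketch of that mixing, reusing the formulas above with an illustrative ctx_len and n_head:

import torch

ctx_len, n_head = 8, 4
curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range(ctx_len)])
curve = curve * 2 + 0.7                              # decays toward 0.7 for old positions, peaks at 2.7 at the latest
for h in range(n_head):
    mix_strength = 1 - 1.2 * h / (n_head - 1)        # 1.0 for head 0 down to -0.2 for the last head
    w = (1 - mix_strength) + curve * mix_strength    # flat profile when mix_strength is 0
    print(h, round(mix_strength, 2), [round(v, 2) for v in w.tolist()])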
@@ -421,7 +425,7 @@ class GPT(nn.Module):
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
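eps now flows from the trainer config into AdamW alongside lr and betas; a minimal sketch of the resulting optimizer construction (the module and single param group here are stand-ins for illustration):

import torch

model = torch.nn.Linear(8, 8)                         # stand-in module
optim_groups = [{"params": model.parameters(), "weight_decay": 0.0}]
optimizer = torch.optim.AdamW(optim_groups, lr=8e-4, betas=(0.9, 0.999), eps=1e-8)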
def forward(self, idx, targets=None):
@@ -433,6 +437,7 @@ class GPT(nn.Module):
x = self.blocks(x)
x = self.ln_f(x)
x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
x = self.head(x)
loss = None

@@ -15,7 +15,8 @@ class TrainerConfig:
max_epochs = 10
batch_size = 64
learning_rate = 4e-4
betas = (0.9, 0.95)
betas = (0.9, 0.99)
eps = 1e-8
grad_norm_clip = 1.0
weight_decay = 0.01
lr_decay = False # linear warmup followed by cosine decay
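lr_decay is documented as linear warmup followed by cosine decay toward lr_final; the trainer's implementation is not shown in this diff, but a standalone sketch of such a schedule (lr_at is an illustrative name) could look like:

import math

def lr_at(tokens, warmup_tokens, final_tokens, lr_init, lr_final):
    if warmup_tokens > 0 and tokens < warmup_tokens:
        return lr_init * tokens / warmup_tokens          # linear warmup
    progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
    coeff = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return lr_final + coeff * (lr_init - lr_final)       # cosine decay to lr_final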

@@ -38,10 +38,11 @@ n_embd = n_head * 64
batch_size = 64
n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
lr_final = 2e-4
betas = (0.9, 0.99)
betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
eps = 1e-8
weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
@@ -91,9 +92,9 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
n_layer=n_layer, n_head=n_head, n_embd=n_embd))
print('model', model_type, 'total epoch', n_epoch, 'batch_size', batch_size, 'n_layer', n_layer, 'n_head', n_head, 'n_embd', n_embd, 'len', ctx_len)
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'ctx', ctx_len)
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
trainer = Trainer(model, train_dataset, None, tconf)
