diff --git a/src/model.py b/src/model.py
index 328a560..3fdaddb 100644
--- a/src/model.py
+++ b/src/model.py
@@ -33,11 +33,11 @@ class RWKV_TimeMix(nn.Module):
         self.key = nn.Linear(config.n_embd, config.n_embd)
         self.value = nn.Linear(config.n_embd, config.n_embd)
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
-
+
         self.output = nn.Linear(config.n_embd, config.n_embd)
 
     def forward(self, x):
-        B, T, C = x.size()
+        B, T, C = x.size()
         TT = self.ctx_len
         w = F.pad(self.time_w, (0, TT))
         w = torch.tile(w, [TT])
@@ -51,7 +51,7 @@ class RWKV_TimeMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
-        k = torch.clamp(k, max=30) # clamp crazy values
+        k = torch.clamp(k, max=30) # clamp extreme values
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
 
@@ -154,7 +154,7 @@ class MHA_rotary(nn.Module):
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
 
-        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
+        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
         k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
         cos, sin = self.rotary_emb(q, seq_len=T)
         q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
@@ -163,7 +163,7 @@ class MHA_rotary(nn.Module):
 
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
         att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
-        att = F.softmax(att, dim = -1) # softmax
+        att = F.softmax(att, dim = -1) # softmax
 
         x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
         x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
@@ -196,7 +196,7 @@ class GeGLU(torch.nn.Module):
 
 ########################################################################################################
 # MHA_pro: with more tricks
-########################################################################################################
+########################################################################################################
 
 class MHA_pro(nn.Module):
     def __init__(self, config, layer_id):
@@ -211,7 +211,7 @@ class MHA_pro(nn.Module):
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
         self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
-        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
         self.query = nn.Linear(config.n_embd, config.n_embd)
@@ -239,7 +239,7 @@ class MHA_pro(nn.Module):
         k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
         v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
 
-        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
+        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
         k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
         cos, sin = self.rotary_emb(q, seq_len=T)
         q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
@@ -283,7 +283,7 @@ class FixedNorm(nn.Module):
         x_normed = x / (norm_x * self.dd + 1e-12)
         return x_normed
 
-########################################################################################################
+########################################################################################################
 
 class GPTConfig:
     def __init__(self, vocab_size, ctx_len, **kwargs):
@@ -300,6 +300,8 @@ class Block(nn.Module):
         self.ln2 = nn.LayerNorm(config.n_embd)
 
         if config.model_type == 'RWKV':
+            # self.ln1 = FixedNorm(config.n_embd)
+            # self.ln2 = FixedNorm(config.n_embd)
             self.attn = RWKV_TimeMix(config, layer_id)
             self.mlp = RWKV_ChannelMix(config, layer_id)
 
@@ -332,6 +334,7 @@ class GPT(nn.Module):
         self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
 
         self.ln_f = nn.LayerNorm(config.n_embd)
+        self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.ctx_len = config.ctx_len
@@ -345,32 +348,33 @@ class GPT(nn.Module):
                         ww[k] *= math.sqrt(self.config.vocab_size)
                     else:
                         ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
                 elif 'head.weight' in k:
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might works better for chinese
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
                 elif 'blocks.' in k:
                     block_id = int(k.split('.')[1])
                     if 'receptance.weight' in k:
-                        ww[k] *= 0 # 0 works the best
+                        ww[k] *= 0 # init with zero matrix
                     elif 'attn.key.weight' in k:
-                        ww[k] *= 0 # 0 works the best
+                        ww[k] *= 0 # init with zero matrix
                     elif 'attn.output.weight' in k:
                         ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                     elif 'mlp.weight.weight' in k:
                         ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                     elif 'attn.time_w' in k:
-                        if self.config.n_head > 1: # different time_w for different head
-                            for h in range(self.config.n_head):
-                                curve = torch.tensor([i for i in range(self.config.ctx_len)]) / (self.config.ctx_len - 1)
-                                curve = torch.pow(curve, 24) # concentrated effect
-                                curve = curve - torch.mean(curve) + 1 # normalize mean to 1
+                        curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
+                        curve = curve * 2 + 0.7
+                        for h in range(self.config.n_head):
+                            if self.config.n_head > 1:
                                 mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
-                                ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                                # special tweaks because of time_shift
-                                ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
-                                ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
-                                ww[k][h][self.config.ctx_len - 1] = 1
-                                # print(k, h, mix_strength, ww[k][h])
+                            else:
+                                mix_strength = 0.5
+                            ww[k][h] = (1 - mix_strength) + curve * mix_strength
+                            # special tweaks because of time_shift
+                            ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
+                            ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
+                            ww[k][h][self.config.ctx_len - 1] = 1
+                            # print(k, h, mix_strength, ww[k][h])
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
 
@@ -421,7 +425,7 @@ class GPT(nn.Module):
             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
         return optimizer
 
     def forward(self, idx, targets=None):
@@ -433,6 +437,7 @@ class GPT(nn.Module):
 
         x = self.blocks(x)
         x = self.ln_f(x)
+        x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
         x = self.head(x)
 
         loss = None
diff --git a/src/trainer.py b/src/trainer.py
index ead5bb2..ec1ff66 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -15,7 +15,8 @@ class TrainerConfig:
     max_epochs = 10
     batch_size = 64
     learning_rate = 4e-4
-    betas = (0.9, 0.95)
+    betas = (0.9, 0.99)
+    eps = 1e-8
     grad_norm_clip = 1.0
     weight_decay = 0.01
     lr_decay = False # linear warmup followed by cosine decay
diff --git a/train.py b/train.py
index 2e42008..188b438 100644
--- a/train.py
+++ b/train.py
@@ -38,10 +38,11 @@ n_embd = n_head * 64
 batch_size = 64
 n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
 
-lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
+lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
 lr_final = 2e-4
 
-betas = (0.9, 0.99)
+betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
+eps = 1e-8
 weight_decay = 0 if model_type == 'RWKV' else 0.01 # seems wd is not very useful when we have enough data
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 
@@ -91,9 +92,9 @@ train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(),
 
 model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, n_layer=n_layer, n_head=n_head, n_embd=n_embd))
 
-print('model', model_type, 'total epoch', n_epoch, 'batch_size', batch_size, 'n_layer', n_layer, 'n_head', n_head, 'n_embd', n_embd, 'len', ctx_len)
+print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'ctx', ctx_len)
 tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
-                      learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas,
+                      learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
                       warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
 trainer = Trainer(model, train_dataset, None, tconf)
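
Note on the new time_w initialisation (the @@ -345,32 +348,33 @@ hunk in src/model.py): each head now gets a blend of a flat profile and an exponential recency curve. With more than one head, mix_strength sweeps from 1 (strongly recency-weighted) down to -0.2 (slightly inverted) across heads, and the last three positions are adjusted to compensate for time_shift. The sketch below reproduces that per-head init in isolation so the resulting profiles can be inspected; it is an illustration only, not part of the patch, and init_time_w with plain ctx_len / n_head arguments is a hypothetical stand-in for the config-driven code above.

import torch

def init_time_w(ctx_len: int, n_head: int) -> torch.Tensor:
    # Exponential curve rising toward the most recent position, rescaled to roughly [0.7, 2.7].
    curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range(ctx_len)])
    curve = curve * 2 + 0.7

    time_w = torch.ones(n_head, ctx_len)
    for h in range(n_head):
        if n_head > 1:
            mix_strength = 1 - 1.2 * h / (n_head - 1)  # from 1 (head 0) down to -0.2 (last head)
        else:
            mix_strength = 0.5
        time_w[h] = (1 - mix_strength) + curve * mix_strength
        # special tweaks because of time_shift (same constants as the patch)
        time_w[h][ctx_len - 3] = (time_w[h][ctx_len - 3] * 2 + 1) / 3
        time_w[h][ctx_len - 2] = (time_w[h][ctx_len - 2] * 1 + 2) / 3
        time_w[h][ctx_len - 1] = 1
    return time_w

print(init_time_w(ctx_len=8, n_head=4))  # one row per head; compare the per-head decay profiles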