From fd098b1d2e8702a2944ea6297bd11a146f31e0c2 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Fri, 13 Aug 2021 02:25:43 +0800
Subject: [PATCH] small update

---
 src/model.py | 30 ++++++++++++------------------
 src/utils.py |  6 +++++-
 train.py     | 12 +++++++-----
 3 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/src/model.py b/src/model.py
index 03021e7..729f4fd 100644
--- a/src/model.py
+++ b/src/model.py
@@ -58,10 +58,9 @@ class RWKV_TimeMix(nn.Module):
         v = v.view(B, T, self.n_head, self.head_size)

         wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
-        y = torch.sigmoid(r) * wkv / sum_k
+        rwkv = torch.sigmoid(r) * wkv / sum_k

-        y = self.output(y) * self.time_gamma[:T, :]
-        return y
+        return self.output(rwkv) * self.time_gamma[:T, :]

 class RWKV_ChannelMix(nn.Module):
     def __init__(self, config, layer_id):
@@ -82,10 +81,10 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)

-        wkv = self.weight(F.mish(k) * v) # mish is a bit better than gelu
-        y = torch.sigmoid(r) * wkv
+        wkv = self.weight(F.mish(k) * v) # mish seems a bit better than gelu
+        rwkv = torch.sigmoid(r) * wkv

-        return y
+        return rwkv

 ########################################################################################################
 # MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
@@ -335,11 +334,6 @@ class GPT(nn.Module):
         self.apply(self._init_weights)

         if self.config.model_type == 'RWKV': # improve orthogonal weight init
-
-            token_diversity = pow(self.config.vocab_size / 200, 1/3)
-            token_diversity = 0.4 * min(max(token_diversity, 1), 2) # 200 -> 0.4, 1600 -> 0.8. ENG-char 0.4 CHN-char 0.8
-            print('token_diversity', token_diversity)
-
             ww = self.state_dict()
             for k in ww:
                 if 'tok_emb' in k:
@@ -347,19 +341,19 @@ class GPT(nn.Module):
                         ww[k] *= math.sqrt(self.config.vocab_size)
                     else:
                         ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= token_diversity
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might work better for Chinese
                 elif 'head.weight' in k:
-                    ww[k] *= token_diversity
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might work better for Chinese
                 elif 'blocks.' in k:
                     block_id = int(k.split('.')[1])
                     if 'receptance.weight' in k:
-                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5 gives similar results
                     elif 'attn.key.weight' in k:
-                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5 gives similar results
                     elif 'attn.output.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                     elif 'mlp.weight.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results

         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

@@ -426,6 +420,6 @@ class GPT(nn.Module):

         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1))
+            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1)) # try increasing smoothing if you see NaN

         return x, loss
diff --git a/src/utils.py b/src/utils.py
index 6192589..5f9bb65 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -27,16 +27,20 @@ def top_p_probs(probs, p):
 def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
     logits = logits[:, pos, :] / temperature
     probs = F.softmax(logits, dim=-1)
+
     if min_p_ratio is not None:
         limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
         logits[probs < limit] = -float('Inf')
+
     if top_k is not None:
         logits = top_k_logits(logits, top_k)
+
     probs = F.softmax(logits, dim=-1)
+
     if top_p is not None:
         probs[0] = top_p_probs(probs[0], top_p)
+
     ix = torch.multinomial(probs, num_samples=1)
-
     return ix[0][0].cpu()

 def set_seed(seed):
diff --git a/train.py b/train.py
index 6ba94e1..a5c01e8 100644
--- a/train.py
+++ b/train.py
@@ -20,6 +20,8 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s
 # MHA_pro - slow (lots of tricks) - VRAM hungry - good performance
 model_type = 'RWKV' # 'RWKV' or 'MHA_rotary' or 'MHA_pro'

+# datafile = u"V:\\NLP\\text8"
+# datafile = u"V:\\NLP\\enwik8"
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
 datafile_encoding = 'utf-8'
 # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
@@ -27,15 +29,15 @@ datafile_encoding = 'utf-8'

 model_level = 'character' # 'character' or 'word'

-ctx_len = 256 # length of ctx window
+ctx_len = 256 # context length
 n_layer = 5
 n_head = 8
 n_embd = n_head * 64

 batch_size = 64

-n_epoch = 50 # the 'epoch' here is very short
-lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
+n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
+lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # it seems RWKV can use a higher lr
 lr_final = 2e-4

 betas = (0.9, 0.99)
@@ -72,7 +74,7 @@ class Dataset(Dataset):
         return epoch_length_fixed

     def __getitem__(self, idx):
-        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # CHEAT: pick a spot in the dataset at random
+        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # cheat: pick a random spot in the dataset
         chunk = self.data[i:i+self.ctx_len+1]
         dix = [self.stoi[s] for s in chunk]
         x = torch.tensor(dix[:-1], dtype=torch.long)
@@ -108,7 +110,7 @@ NUM_OF_RUNS = 5
 LENGTH_OF_EACH = 300

 for run in range(NUM_OF_RUNS):

-    context = "It was"
+    context = "it was"
     if model_level == 'word':
         x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
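
Note on the src/model.py init change above: the patch drops the data-dependent token_diversity factor and hard-codes a 0.4 multiplier on top of the orthogonal init, with fixed per-key multipliers for the block weights. For reference, a minimal self-contained sketch of the same rescaling rule follows; TinyRWKVLike and rescale_rwkv_init are hypothetical illustrations (not code from this repo), with parameter names chosen only so the key matching mirrors the patched GPT.__init__.

import math
import torch
import torch.nn as nn

# Toy stand-in model: the parameter names ('tok_emb', 'blocks.N.attn.*',
# 'blocks.N.mlp.weight', 'head') match the keys the patch scales, but the
# module itself is hypothetical and much simpler than the repo's GPT.
class TinyRWKVLike(nn.Module):
    def __init__(self, vocab_size=200, n_embd=512, n_layer=2):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                'attn': nn.ModuleDict({
                    'key': nn.Linear(n_embd, n_embd, bias=False),
                    'receptance': nn.Linear(n_embd, n_embd, bias=False),
                    'output': nn.Linear(n_embd, n_embd, bias=False),
                }),
                'mlp': nn.ModuleDict({
                    'weight': nn.Linear(n_embd, n_embd, bias=False),
                }),
            }) for _ in range(n_layer)
        ])
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

@torch.no_grad()
def rescale_rwkv_init(model, vocab_size, n_embd):
    # Start from an orthogonal init (the patch assumes _init_weights already did
    # this), then apply the same per-key multipliers as the patched GPT.__init__.
    for p in model.parameters():
        if p.ndim >= 2:
            nn.init.orthogonal_(p)
    ww = model.state_dict()  # tensors share storage with the parameters
    for k in ww:
        if 'tok_emb' in k:
            ww[k] *= math.sqrt(vocab_size) if vocab_size > n_embd else math.sqrt(n_embd)
            ww[k] *= 0.4  # 0.4 is the safe default; ~0.8 for Chinese char-level data
        elif 'head.weight' in k:
            ww[k] *= 0.4
        elif 'blocks.' in k:
            block_id = int(k.split('.')[1])
            if 'receptance.weight' in k or 'attn.key.weight' in k:
                ww[k] *= 0.2                         # 0.2 ~ 0.5 gives similar results
            elif 'attn.output.weight' in k or 'mlp.weight.weight' in k:
                ww[k] *= 1 / pow(1 + block_id, 0.5)  # 0.5 ~ 0.7 gives similar results

model = TinyRWKVLike()
rescale_rwkv_init(model, vocab_size=200, n_embd=512)

The apparent intent, going by the patch comments, is to start the receptance/key projections small and to shrink each block's output with depth; the comments note that nearby values (0.2 ~ 0.5, exponent 0.5 ~ 0.7) give similar results.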
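
Note on the src/utils.py hunk: the change there is only whitespace, but the ordering inside sample_logits is easy to get wrong: the min-p filter is applied to the raw probabilities, top-k masks the logits, the distribution is re-softmaxed, and only then is top-p applied before sampling. A self-contained sketch of that flow, assuming batch size 1; the top_k_logits and top_p_probs helpers are re-implemented here for illustration and may differ in detail from the repo's versions.

import torch
import torch.nn.functional as F

def top_k_logits(logits, k):
    # keep only the k largest logits; everything else -> -inf
    v, _ = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[..., [-1]]] = -float('Inf')
    return out

def top_p_probs(probs, p):
    # zero out the tail of the sorted distribution beyond cumulative mass p, then renormalize
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    cutoff = cumulative > p
    cutoff[..., 0] = False  # always keep the most likely token
    sorted_probs[cutoff] = 0.0
    out = torch.zeros_like(probs)
    out.scatter_(-1, sorted_idx, sorted_probs)
    return out / out.sum()

def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
    logits = logits[:, pos, :] / temperature
    probs = F.softmax(logits, dim=-1)

    if min_p_ratio is not None:
        # drop tokens whose probability is far below the maximum probability
        limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
        logits[probs < limit] = -float('Inf')

    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    probs = F.softmax(logits, dim=-1)

    if top_p is not None:
        probs[0] = top_p_probs(probs[0], top_p)

    ix = torch.multinomial(probs, num_samples=1)
    return ix[0][0].cpu()

# hypothetical usage: logits of shape (batch=1, seq_len, vocab); sample the last position
logits = torch.randn(1, 16, 100)
token = sample_logits(logits, pos=-1, temperature=0.9, top_k=40, top_p=0.95)
print(int(token))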