small update

main
BlinkDL 4 years ago
parent 3b01c8c3cf
commit fd098b1d2e

@@ -58,10 +58,9 @@ class RWKV_TimeMix(nn.Module):
         v = v.view(B, T, self.n_head, self.head_size)
         wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
-        y = torch.sigmoid(r) * wkv / sum_k
-        y = self.output(y) * self.time_gamma[:T, :]
-        return y
+        rwkv = torch.sigmoid(r) * wkv / sum_k
+        return self.output(rwkv) * self.time_gamma[:T, :]

 class RWKV_ChannelMix(nn.Module):
     def __init__(self, config, layer_id):
@@ -82,10 +81,10 @@ class RWKV_ChannelMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
-        wkv = self.weight(F.mish(k) * v) # mish is a bit better than gelu
-        y = torch.sigmoid(r) * wkv
-        return y
+        wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu
+        rwkv = torch.sigmoid(r) * wkv
+        return rwkv

 ########################################################################################################
 # MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
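
Note: both mixes now name the gated result rwkv and use sigmoid(r) as a per-channel gate on wkv. A minimal standalone sketch of the channel-mix path, with assumed layer sizes (hidden_mult is not from the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelMixSketch(nn.Module):
    # Minimal sketch of the gated channel mix: sigmoid(r) acts as a per-channel gate
    def __init__(self, n_embd, hidden_mult=4):  # hidden_mult is an assumption
        super().__init__()
        self.key = nn.Linear(n_embd, hidden_mult * n_embd)
        self.value = nn.Linear(n_embd, hidden_mult * n_embd)
        self.weight = nn.Linear(hidden_mult * n_embd, n_embd)
        self.receptance = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)
        wkv = self.weight(F.mish(k) * v)   # mish(k) * v, projected back to n_embd
        return torch.sigmoid(r) * wkv      # receptance gate on the mixed channels
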
@@ -335,11 +334,6 @@ class GPT(nn.Module):
         self.apply(self._init_weights)
         if self.config.model_type == 'RWKV': # improve orthogonal weight init
-            token_diversity = pow(self.config.vocab_size / 200, 1/3)
-            token_diversity = 0.4 * min(max(token_diversity, 1), 2) # 200 -> 0.4, 1600 -> 0.8. ENG-char 0.4 CHN-char 0.8
-            print('token_diversity', token_diversity)
             ww = self.state_dict()
             for k in ww:
                 if 'tok_emb' in k:
@@ -347,19 +341,19 @@ class GPT(nn.Module):
                         ww[k] *= math.sqrt(self.config.vocab_size)
                     else:
                         ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= token_diversity
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might work better for Chinese
                 elif 'head.weight' in k:
-                    ww[k] *= token_diversity
+                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might work better for Chinese
                 elif 'blocks.' in k:
                     block_id = int(k.split('.')[1])
                     if 'receptance.weight' in k:
-                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5 gives similar results
                     elif 'attn.key.weight' in k:
-                        ww[k] *= 0.2 # 0.2 ~ 0.5
+                        ww[k] *= 0.2 # 0.2 ~ 0.5 gives similar results
                     elif 'attn.output.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
                     elif 'mlp.weight.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7
+                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results

         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
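
Taken together, the two init hunks above replace the vocab-dependent token_diversity factor with a flat 0.4 and keep the depth-dependent damping of output projections. A rough standalone sketch of the same rescaling loop; the tok_emb branch condition is not visible in the diff and is assumed here:

import math

def rescale_rwkv_init(ww, vocab_size, n_embd):
    # ww: model.state_dict(); scale selected tensors in place, mirroring the hunks above
    for k in ww:
        if 'tok_emb' in k:
            # the diff shows sqrt(vocab_size) in one branch and sqrt(n_embd) in the other;
            # the branch condition sits outside the hunk, so it is assumed here
            ww[k] *= math.sqrt(vocab_size if vocab_size > n_embd else n_embd)
            ww[k] *= 0.4                      # flat factor replacing the old token_diversity
        elif 'head.weight' in k:
            ww[k] *= 0.4
        elif 'blocks.' in k:
            block_id = int(k.split('.')[1])   # e.g. 'blocks.3.attn.key.weight' -> 3
            if 'receptance.weight' in k or 'attn.key.weight' in k:
                ww[k] *= 0.2                  # 0.2 ~ 0.5 reportedly gives similar results
            elif 'attn.output.weight' in k or 'mlp.weight.weight' in k:
                ww[k] *= 1 / pow(1 + block_id, 0.5)   # damp deeper blocks (exponent 0.5 ~ 0.7)
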
@@ -426,6 +420,6 @@ class GPT(nn.Module):
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1))
+            loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(x.view(-1, x.size(-1)), targets.view(-1)) # try increasing smoothing if you see nan
         return x, loss
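
LabelSmoothingCrossEntropy itself is not part of this diff; a generic label-smoothing cross entropy of the kind referenced here could look like the sketch below (an assumption, not the repo's exact class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCE(nn.Module):
    # Generic label-smoothing cross entropy; the repo's own class is not shown in this diff
    def __init__(self, smoothing=1e-6):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, targets):
        # logits: (N, vocab), targets: (N,)
        logp = F.log_softmax(logits, dim=-1)
        nll = -logp.gather(1, targets.unsqueeze(1)).squeeze(1)   # standard NLL term
        smooth = -logp.mean(dim=-1)                              # uniform (smoothing) term
        return ((1 - self.smoothing) * nll + self.smoothing * smooth).mean()
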

@@ -27,16 +27,20 @@ def top_p_probs(probs, p):
 def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
     logits = logits[:, pos, :] / temperature
     probs = F.softmax(logits, dim=-1)
     if min_p_ratio is not None:
         limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
         logits[probs < limit] = -float('Inf')
     if top_k is not None:
         logits = top_k_logits(logits, top_k)
     probs = F.softmax(logits, dim=-1)
     if top_p is not None:
         probs[0] = top_p_probs(probs[0], top_p)
     ix = torch.multinomial(probs, num_samples=1)
     return ix[0][0].cpu()

 def set_seed(seed):
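
sample_logits relies on top_k_logits and top_p_probs, which are defined elsewhere in the repo. The helpers below are common implementations sketched for completeness; they are assumptions, not necessarily the repo's exact code:

import torch

def top_k_logits(logits, k):
    # keep only the k largest logits per row; everything else becomes -inf
    v, _ = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[..., [-1]]] = -float('Inf')
    return out

def top_p_probs(probs, p):
    # 1-D nucleus filter: keep the smallest set of top tokens whose mass reaches p
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    cutoff = int((cum < p).sum().item()) + 1           # always keep at least one token
    sorted_probs[cutoff:] = 0
    out = torch.zeros_like(probs).scatter_(0, sorted_idx, sorted_probs)
    return out / out.sum()                              # renormalize the kept mass
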

@@ -20,6 +20,8 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s
 # MHA_pro - slow (lots of tricks) - VRAM hungry - good performance
 model_type = 'RWKV' # 'RWKV' or 'MHA_rotary' or 'MHA_pro'
+# datafile = u"V:\\NLP\\text8"
+# datafile = u"V:\\NLP\\enwik8"
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
 datafile_encoding = 'utf-8'
 # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
@@ -27,15 +29,15 @@ datafile_encoding = 'utf-8'
 model_level = 'character' # 'character' or 'word'
-ctx_len = 256 # length of ctx window
+ctx_len = 256 # context length
 n_layer = 5
 n_head = 8
 n_embd = n_head * 64
 batch_size = 64
-n_epoch = 50 # the 'epoch' here is very short
-lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
+n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
+lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # seems RWKV can use higher lr
 lr_final = 2e-4
 betas = (0.9, 0.99)
@@ -72,7 +74,7 @@ class Dataset(Dataset):
         return epoch_length_fixed

     def __getitem__(self, idx):
-        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # CHEAT: pick a spot in the dataset at random
+        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # cheat: pick a random spot in dataset
         chunk = self.data[i:i+self.ctx_len+1]
         dix = [self.stoi[s] for s in chunk]
         x = torch.tensor(dix[:-1], dtype=torch.long)
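
For context, a minimal self-contained version of the random-crop dataset this hunk touches; the epoch_length_fixed default and the character-vocabulary setup are assumptions based on the surrounding names:

import numpy as np
import torch
from torch.utils.data import Dataset

class CharDataset(Dataset):
    # Minimal sketch of the random-crop dataset (names follow the diff; defaults are assumed)
    def __init__(self, data, ctx_len, epoch_length_fixed=10000):
        chars = sorted(set(data))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.data = data
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed   # an 'epoch' is a fixed number of samples

    def __len__(self):
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        # idx is ignored: every sample is a random window, so each 'epoch' sees different text
        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
        chunk = self.data[i:i + self.ctx_len + 1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)   # input tokens
        y = torch.tensor(dix[1:], dtype=torch.long)    # next-token targets, shifted by one
        return x, y
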
@@ -108,7 +110,7 @@ NUM_OF_RUNS = 5
 LENGTH_OF_EACH = 300

 for run in range(NUM_OF_RUNS):
-    context = "It was"
+    context = "it was"

     if model_level == 'word':
         x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
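
A rough sketch of the per-run character-level generation loop this context string feeds into; model, device, ctx_len, and train_dataset.itos are assumed from the rest of the script and may differ:

# character-level generation sketch (assumed surroundings: model, device, ctx_len, itos)
context = "it was"
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)

for _ in range(LENGTH_OF_EACH):
    xt = torch.tensor(x[-ctx_len:], dtype=torch.long)[None, ...].to(device)  # last ctx_len tokens
    with torch.no_grad():
        logits, _ = model(xt)                                 # model returns (logits, loss)
    ix = sample_logits(logits, pos=xt.shape[1] - 1, temperature=1.0, top_p=0.8)
    x = np.append(x, int(ix))                                 # append sampled token id

# itos (inverse of stoi) is assumed to exist alongside stoi
print(''.join(train_dataset.itos[int(i)] for i in x))
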
