commit b2a240d73d (parent bc47cb9f1a), branch main, by BlinkDL, 3 years ago

@@ -32,6 +32,10 @@ class MyDataset(Dataset):
             self.data_size = len(self.data._bin_buffer) // 2
             print(f"Data has {self.data_size} tokens.")
+
+            if args.my_qa_mask == 1:
+                self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
+                self.data_pile_size = len(self.data_pile._bin_buffer) // 2
             if args.my_pile_stage > 0:
                 # assert self.data_size == 332115325534 and self.vocab_size == 50277
                 self.samples_per_epoch = args.epoch_steps * args.real_bsz
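Note (not part of the diff): when the new --my_qa_mask flag is set, the dataset opens a second MMapIndexedDataset over a Pile dump alongside the primary QA data; __getitem__ below interleaves samples from the two corpora and returns a per-token mask so the loss can be confined to answer tokens. The // 2 assumes the binidx file stores 2-byte (uint16) token ids.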
@@ -146,25 +150,69 @@ class MyDataset(Dataset):
         else:
             ctx_len = args.ctx_len
             req_len = ctx_len + 1
+            magic_prime = args.magic_prime
+            data = self.data
 
             if args.my_pile_stage > 0:
                 ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
+
+                if args.my_qa_mask == 1:
+                    ii_orig = ii
+                    if ii % 2 == 0:
+                        ii = (ii // 2) * args.magic_prime
+                        magic_prime = 324331313
+                        data = self.data_pile
+                    else:
+                        ii = ii // 2
+
                 factor = (math.sqrt(5) - 1) / 2
-                factor = int(args.magic_prime * factor)
-                i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
-                i = i + args.my_pile_shift
+                factor = int(magic_prime * factor)
+                i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
+                if (args.my_qa_mask == 0) or (data == self.data_pile):
+                    i = i + args.my_pile_shift
                 # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
             else:
                 # cheat: pick a random spot in dataset
                 i = np.random.randint(0, self.data_size - req_len)
 
             if args.data_type == "binidx":
-                dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
+                dix = data.get(idx=0, offset=i, length=req_len).astype(int)
             elif args.data_type == "numpy":
-                dix = self.data[i : i + req_len]
+                dix = data[i : i + req_len]
             else:
-                dix = [self.stoi[s] for s in self.data[i : i + req_len]]
+                dix = [self.stoi[s] for s in data[i : i + req_len]]
+
+            if args.my_qa_mask == 1:
+                if data == self.data_pile:
+                    z = [1] * ctx_len
+                else:
+                    z = [0] * ctx_len
+                    z_sum = 0
+                    isGood = False
+                    for i in range(3, ctx_len):
+                        if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
+                            isGood = True
+                        if dix[i] == 0:
+                            isGood = False
+                        if isGood:
+                            z[i] = 1
+                            z_sum += 1
+                    if z_sum == 0:
+                        z = [1] * ctx_len
+                        i = np.random.randint(0, self.data_pile_size - req_len)
+                        dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
+                z = torch.tensor(z, dtype=torch.bfloat16)
 
             x = torch.tensor(dix[:-1], dtype=torch.long)
             y = torch.tensor(dix[1:], dtype=torch.long)
+
+            # if ii_orig < 50:
+            #     # if rank == 1:
+            #     print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
+            # else:
+            #     exit(0)
+
+            if args.my_qa_mask == 1:
+                return x, y, z
+
             return x, y
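Note on the indexing above (not part of the diff): factor is magic_prime scaled by the golden-ratio conjugate (sqrt(5) - 1) / 2, and ii -> (factor * ii**3) % magic_prime is a bijection on the residues whenever magic_prime is a prime with magic_prime % 3 == 2 (which 324331313 satisfies), so successive samples visit distinct, well-spread ctx_len-sized chunks. With my_qa_mask == 1, even ii are routed to the Pile copy and odd ii to the QA data, interleaving the corpora 1:1; the mask z marks presumed answer spans, treating the token run 187, 187, 34, 27 as a delimiter (our reading: likely something like "\n\nA:" under the 20B tokenizer, an assumption the diff does not confirm) and token 0 as end-of-span. Note also that the for loop reuses the name i, clobbering the chunk offset; this is harmless here because i is reassigned by randint before its only later use. A minimal sketch of the permutation with toy values:

    import math

    # Toy stand-ins; in the diff these come from args. Assumption: magic_prime
    # is a prime with magic_prime % 3 == 2, so cubing is a bijection mod it.
    magic_prime = 11
    ctx_len = 4

    factor = int(magic_prime * (math.sqrt(5) - 1) / 2)

    # Each ii maps to a distinct chunk offset; together they cover every residue once.
    offsets = [((factor * ii * ii * ii) % magic_prime) * ctx_len
               for ii in range(1, magic_prime + 1)]
    assert sorted(o // ctx_len for o in offsets) == list(range(magic_prime))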

@@ -108,7 +108,7 @@ class RWKV_TimeMix(MyModule):
         self.layer_id = layer_id
         self.ctx_len = args.ctx_len
         self.n_embd = args.n_embd
+        self.my_testing = self.args.my_testing
 
         attn_sz = args.n_embd
 
         with torch.no_grad():  # fancy init
@@ -142,6 +142,9 @@ class RWKV_TimeMix(MyModule):
         self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
 
+        # if self.my_testing > 0:
+        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
+
     @MyFunction
     def jit_func(self, x):
@@ -174,7 +177,7 @@ class RWKV_ChannelMix(MyModule):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
+        self.my_testing = self.args.my_testing
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
         with torch.no_grad():  # fancy init of time_mix
@@ -192,6 +195,12 @@ class RWKV_ChannelMix(MyModule):
         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
         self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
 
+        # if self.my_testing in [1]:
+        #     self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
+        # elif self.my_testing in [2]:
+        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
+
     @MyFunction
     def forward(self, x):
         xx = self.time_shift(x)
@@ -205,6 +214,19 @@ class RWKV_ChannelMix(MyModule):
         rkv = torch.sigmoid(self.receptance(xr)) * kv
         return rkv
 
+        # k = self.key(xk)
+        # # if self.my_testing in [0, 2]:
+        # k = torch.square(torch.relu(k))
+        # # elif self.my_testing == 1:
+        # #     k = torch.square(torch.relu(k)) + k * self.aaa
+        # kv = self.value(k)
+        # r = self.receptance(xr)
+        # # if self.my_testing == 0:
+        # r = torch.sigmoid(r)
+        # # elif self.my_testing == 2:
+        # #     r = torch.sigmoid(r) + r * self.aaa
+        # rkv = r * kv
+        # return rkv
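Note (not part of the diff): the commented block preserves a --my_testing experiment that varied the channel-mix nonlinearity, adding a learned per-channel term (k * self.aaa on the squared-ReLU key, or r * self.aaa on the sigmoid receptance). It is dead code in this commit, placed after the return, so the active forward path is unchanged.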
########################################################################################################
# The RWKV Model with our blocks
@@ -401,9 +423,38 @@ class RWKV(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         args = self.args
-        idx, targets = batch
-        logits = self(idx)
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        if args.my_qa_mask == 0:
+            idx, targets = batch
+            logits = self(idx)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        else:
+            idx, targets, mask = batch
+            mask = mask.view(-1)
+            sum_mask = torch.sum(mask).item()
+            # if sum_mask == 0:
+            #     return torch.tensor([0.0], requires_grad=True)
+            logits = self(idx)
+            if sum_mask == mask.shape[0]:
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+                # print('rank', self.global_rank, 'loss', loss.item())
+            else:
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+                # loss_raw = loss
+                loss = torch.sum(loss * mask) / sum_mask
+                # torch.set_printoptions(threshold=10000)
+                # if True: #self.global_rank == 1:
+                #     tmp = ''
+                #     sss = 0
+                #     ccc = 0
+                #     for i in range(mask.shape[0]):
+                #         if mask[i] > 0:
+                #             tmp += str(idx.view(-1)[i].item()) + ','
+                #             sss += loss_raw.view(-1)[i].float().item()
+                #             ccc += 1
+                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
         return L2Wrap.apply(loss, logits)
 
     def training_step_end(self, batch_parts):
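Note (not part of the diff): the masked branch computes per-token cross-entropy and averages over masked positions only; the commented-out guard hints that sum_mask == 0 would otherwise divide by zero. A self-contained sketch of the same computation with toy shapes:

    import torch
    import torch.nn.functional as F

    # Toy shapes; in the diff, logits come from self(idx) and mask from the batch.
    B, T, V = 2, 8, 50
    logits = torch.randn(B, T, V)
    targets = torch.randint(0, V, (B, T))
    mask = torch.zeros(B, T, dtype=torch.bfloat16)
    mask[:, T // 2:] = 1  # pretend only the second half holds "answer" tokens
    mask = mask.view(-1)

    sum_mask = torch.sum(mask).item()
    if sum_mask == mask.shape[0]:
        # every position counts: plain mean cross-entropy
        loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))
    else:
        # per-token loss, zeroed outside the mask, averaged over masked tokens only
        loss = F.cross_entropy(logits.view(-1, V), targets.view(-1), reduction="none")
        loss = torch.sum(loss * mask) / sum_mask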

@@ -98,7 +98,7 @@ class train_callback(pl.Callback):
                     lll["kt/s"] = kt_s
                 trainer.my_wandb.log(lll, step=int(real_step))
         if args.magic_prime > 0:
-            if int(real_step) == int(args.magic_prime // args.real_bsz) - 1:
+            if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
                 to_save_dict = pl_module.state_dict()
                 torch.save(
                     to_save_dict,
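Note (not part of the diff): the new (1 + args.my_qa_mask) factor matches the dataset change above; with the QA mask on, the sample index ii is halved and routed to one of two corpora, so covering the data takes twice as many steps. With hypothetical values magic_prime = 324331313, real_bsz = 1024, my_qa_mask = 1, the final checkpoint fires at int(324331313 * 2 // 1024) - 1 = 633458.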

@@ -100,6 +100,7 @@ if __name__ == "__main__":
     parser.add_argument("--my_pos_emb", default=0, type=int)
     parser.add_argument("--load_partial", default=0, type=int)
     parser.add_argument("--magic_prime", default=0, type=int)
+    parser.add_argument("--my_qa_mask", default=0, type=int)
     parser.add_argument("--my_testing", default=0, type=int)
 
     parser = Trainer.add_argparse_args(parser)
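A hypothetical invocation exercising the new flag (paths and values are placeholders, not taken from the commit):

    python train.py --data_file /path/to/qa_text_document --data_type binidx \
        --ctx_len 1024 --my_pile_stage 3 --magic_prime 324331313 --my_qa_mask 1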
