commit e2ec7ae023 (parent c7b1900270)
Author: BlinkDL
Branch: main

@@ -32,7 +32,7 @@ class MyDataset(Dataset):
            self.data_size = len(self.data._bin_buffer) // 2
            rank_zero_info(f"Data has {self.data_size} tokens.")
-            if args.my_qa_mask == 1:
+            if args.my_qa_mask > 0:
                self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
                self.data_pile_size = len(self.data_pile._bin_buffer) // 2
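Review note: the `// 2` assumes each token id occupies two bytes (uint16), which holds for the 20B tokenizer's ~50k vocab. A minimal sketch of the equivalence, with a hypothetical file path:

```python
import numpy as np

# hypothetical .bin file; token ids assumed stored as uint16 (hence the // 2)
raw = np.fromfile("pile_20B_tokenizer_text_document.bin", dtype=np.uint8)
tokens = raw.view(np.uint16)           # reinterpret byte pairs as token ids
assert len(tokens) == len(raw) // 2    # same count as len(_bin_buffer) // 2
```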
@@ -156,7 +156,7 @@ class MyDataset(Dataset):
        if args.my_pile_stage > 0:
            ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
-            if args.my_qa_mask == 1:
+            if args.my_qa_mask > 0:
                ii_orig = ii
                if ii % 2 == 0:
                    ii = (ii // 2) * args.magic_prime
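Review note: the `magic_prime` remap works because multiplying by a constant coprime to the number of samples permutes the sample indices, so each sample is still visited exactly once per pass. A toy illustration (values hypothetical; the real `magic_prime` is chosen for the actual dataset size):

```python
# i -> (i * p) % n is a bijection on 0..n-1 whenever gcd(p, n) == 1
n, p = 10, 7                             # toy sample count and "magic prime"
order = [(i * p) % n for i in range(n)]
assert sorted(order) == list(range(n))   # a full permutation, no repeats
```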

@@ -108,11 +108,13 @@ class RWKV_TimeMix(MyModule):
        self.layer_id = layer_id
        self.ctx_len = args.ctx_len
        self.n_embd = args.n_embd
-        self.my_testing = self.args.my_testing

        with torch.no_grad():  # fancy init
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+            ddd = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                ddd[0, 0, i] = i / args.n_embd

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
@@ -126,12 +128,9 @@ class RWKV_TimeMix(MyModule):
            self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)

            # fancy time_mix
-            x = torch.ones(1, 1, args.n_embd)
-            for i in range(args.n_embd):
-                x[0, 0, i] = i / args.n_embd
-            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
-            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
+            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
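Review note: this refactor replaces the per-class copies of the channel-ramp init with the single `ddd` buffer created above. What the shared buffer computes, shown standalone with hypothetical sizes (the loop is equivalent to a vectorized `torch.arange`):

```python
import torch

n_embd, n_layer, layer_id = 512, 12, 3           # hypothetical sizes
ratio_0_to_1 = layer_id / (n_layer - 1)          # 0 at the first layer, 1 at the last
ratio_1_to_almost0 = 1.0 - layer_id / n_layer    # 1 at the first layer, ~0 at the last

# per-channel ramp 0, 1/n_embd, ..., (n_embd-1)/n_embd, same values as the ddd loop
ddd = torch.arange(n_embd, dtype=torch.float32).div(n_embd).view(1, 1, n_embd)

# low-index channels get small weights (more of the previous token); deeper layers
# get exponents near 0, pushing weights toward 1 (more of the current token)
time_mix_k = ddd ** ratio_1_to_almost0
time_mix_v = ddd ** ratio_1_to_almost0 + 0.3 * ratio_0_to_1
time_mix_r = ddd ** (0.5 * ratio_1_to_almost0)
```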
@@ -147,24 +146,17 @@ class RWKV_TimeMix(MyModule):
            self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
            self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
            with torch.no_grad():
-                x = torch.ones(1, 1, args.n_embd)
-                for i in range(args.n_embd):
-                    x[0, 0, i] = i / args.n_embd
-                self.time_mix_qq = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-                self.time_mix_kk = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-                self.time_mix_vv = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+                self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+                self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+                self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)

    if 'a' not in os.environ["RWKV_MY_TESTING"]:
        @MyFunction
        def jit_func(self, x):
-            # Mix x with the previous timestep to produce xk, xv, xr
-            xx = self.time_shift(x)
+            xx = self.time_shift(x)  # Mix x with the previous timestep to produce xk, xv, xr
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
-            # Use xk, xv, xr to produce k, v, r
            k = self.key(xk)
            v = self.value(xv)
            r = self.receptance(xr)
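Review note: `self.time_shift` is `nn.ZeroPad2d((0, 0, 1, -1))`, which on a `(batch, time, channel)` tensor pads one zero step at the front of the time axis and crops the last step, so position `t` sees position `t-1`. A minimal sketch of the shift plus the lerp above (shapes hypothetical):

```python
import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # (left, right, top, bottom) over (time, channel)
x = torch.randn(2, 5, 8)                  # (batch, time, channel)
xx = time_shift(x)                        # xx[:, t] == x[:, t-1], zeros at t == 0
assert torch.equal(xx[:, 1:], x[:, :-1])
assert torch.equal(xx[:, 0], torch.zeros(2, 8))

time_mix_k = torch.rand(1, 1, 8)             # stands in for the learned parameter
xk = x * time_mix_k + xx * (1 - time_mix_k)  # per-channel lerp of tokens t and t-1
```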
@@ -188,25 +180,20 @@ class RWKV_TimeMix(MyModule):
        @MyFunction
        def jit_funcQKV(self, x):
-            # Mix x with the previous timestep to produce xk, xv, xr
-            xx = self.time_shift(x)
+            xx = self.time_shift(x)  # Mix x with the previous timestep to produce xk, xv, xr
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
            xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
            xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
            xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
-            # Use xk, xv, xr to produce k, v, r
            k = self.key(xk)
            v = self.value(xv)
            r = self.receptance(xr)
            sr = torch.sigmoid(r)
            qq = self.qq(xqq)
            kk = self.kk(xkk)
            vv = self.vv(xvv)
            return sr, k, v, qq, kk, vv

    def forward(self, x):
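Review note: `sr = torch.sigmoid(r)` turns the receptance projection into a per-channel gate in (0, 1); later in the forward pass (outside this hunk) it scales the mixing output. A hedged sketch of the gating idea:

```python
import torch

r = torch.randn(2, 5, 8)      # raw receptance projection (hypothetical shapes)
wkv = torch.randn(2, 5, 8)    # stand-in for the WKV recurrence output
out = torch.sigmoid(r) * wkv  # sr near 0 suppresses a channel, near 1 passes it through
```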
@@ -223,18 +210,15 @@ class RWKV_ChannelMix(MyModule):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
-        self.my_testing = self.args.my_testing
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
-            x = torch.ones(1, 1, args.n_embd)
+            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
-                x[0, 0, i] = i / args.n_embd
-            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+                ddd[0, 0, i] = i / args.n_embd
+            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
@@ -255,7 +239,6 @@ class MishGLU(MyModule):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
-        self.my_testing = self.args.my_testing
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
@@ -478,7 +461,7 @@ class RWKV(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        args = self.args
-        if args.my_qa_mask == 0:
+        if args.my_qa_mask != 1:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
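Review note: flipping `== 0` to `!= 1` keeps this plain cross-entropy path for every `my_qa_mask` value other than 1. The `my_qa_mask == 1` branch carries a per-token mask in the batch; a hedged sketch of the usual masked-loss shape (names assumed, not taken from this hunk):

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(logits, targets, mask):
    # cross-entropy averaged only over tokens the mask marks as trainable
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    )
    mask = mask.view(-1).float()
    return torch.sum(loss * mask) / torch.sum(mask)
```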

@@ -154,14 +154,32 @@ def generate_init_weight(model, init_weight_name):
    mm = model.generate_init_weight()

    if model.args.my_pile_stage == 1:
-        try:
+        if len(model.args.load_model) > 0:
            print(f"Combine weights from {model.args.load_model}...")
            load_dict = torch.load(model.args.load_model, map_location="cpu")
            for k in load_dict:
                assert k in mm
-                mm[k] = load_dict[k].reshape(mm[k].shape)
-        except:
-            print(f"\n\n!!! FAIL !!!\n\n")
+                src = load_dict[k]
+                try:
+                    mm[k] = src.reshape(mm[k].shape)
+                except:
+                    tmp = mm[k].squeeze().clone()
+                    print(k, src.shape, '-->', mm[k].shape)
+                    ss = src.shape[0]
+                    dd = tmp.shape[0]
+                    for i in range(dd):
+                        pos = i / dd * ss
+                        if pos >= ss - 1:
+                            tmp[i] = src[ss-1]
+                        else:
+                            p0 = int(math.floor(pos))
+                            ii = pos - p0
+                            tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
+                    mm[k] = tmp.reshape(mm[k].shape)
+                    sss = src.squeeze().float().cpu().numpy()
+                    print(sss[:10], '...', sss[-10:])
+                    mmm = mm[k].squeeze().float().cpu().numpy()
+                    print(mmm[:10], '...', mmm[-10:])

    print(f"Save to {init_weight_name}...")
    torch.save(mm, init_weight_name)
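Review note: instead of aborting on a shape mismatch (the old `!!! FAIL !!!` path), the new `except` branch linearly resamples a stored 1-D tensor to the freshly initialized length. The same logic as a self-contained helper (name hypothetical):

```python
import math
import torch

def resample_1d(src: torch.Tensor, new_len: int) -> torch.Tensor:
    """Linearly interpolate a 1-D tensor to new_len entries, as in the except branch."""
    ss = src.shape[0]
    out = torch.empty(new_len, dtype=src.dtype)
    for i in range(new_len):
        pos = i / new_len * ss
        if pos >= ss - 1:
            out[i] = src[ss - 1]  # clamp at the right edge
        else:
            p0 = int(math.floor(pos))
            frac = pos - p0
            out[i] = src[p0] * (1 - frac) + src[p0 + 1] * frac
    return out
```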
