improvements

main
BlinkDL 3 years ago
parent f81349f127
commit 778c0a7f58

src/dataset.py
@@ -2,12 +2,13 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import json, math
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime
class MyDataset(Dataset):
@@ -20,6 +21,18 @@ class MyDataset(Dataset):
print("current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data._bin_buffer) // 2
print(f"data has {self.data_size} tokens.")
if args.my_pile_mode > 0:
assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
assert self.samples_per_epoch == 40320
print("########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
self.magic_prime = 324331313
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(self.magic_prime)
assert self.magic_prime % 3 == 2
assert self.magic_prime / dataset_slot > 0.999999 and self.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
@@ -48,15 +61,29 @@ class MyDataset(Dataset):
self.itos = {i: ch for i, ch in enumerate(unique)}
def __len__(self):
return self.args.epoch_steps * int(self.args.devices) * self.args.micro_bsz
return self.args.epoch_steps * self.args.micro_bsz
def __getitem__(self, idx):
#
# we are cheating: pick a random spot in dataset
#
rank = self.global_rank
epoch = self.real_epoch
world_size = self.world_size
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
ctx_len = self.args.ctx_len
req_len = ctx_len + 1
i = np.random.randint(0, self.data_size - req_len)
if self.args.my_pile_mode > 0:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
factor = (math.sqrt(5) - 1) / 2
factor = int(self.magic_prime * factor)
i = ((factor * ii * ii * ii) % self.magic_prime) * ctx_len
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
else:
i = np.random.randint(0, self.data_size - req_len)
if "MMapIndexedDataset" in str(type(self.data)):
dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
elif "numpy" in str(type(self.data)):

src/model.py
@@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc, time
import os, math, gc
from re import L
import torch
import torch.nn as nn
@@ -20,7 +20,7 @@ def __nop(ob):
MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT"] == "1":
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
@@ -273,9 +273,31 @@ class RWKV(pl.LightningModule):
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
def configure_optimizers(self):
lr_1x = set()
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if ("time_mix" in n) and (self.args.my_pile_mode == 2):
lr_2x.add(n)
elif "time_decay" in n:
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
else:
lr_1x.add(n)
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
# print('1x', lr_1x)
# print('2x', lr_2x)
# print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()}
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
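# --- Sketch: the extra "my_lr_scale" key rides along in each param group. torch optimizers keep
# unrecognized keys in their param_group dicts, which is what lets the LR callback in the training
# script read group["my_lr_scale"] and scale group["lr"] per group. torch.optim.Adam stands in
# here for FusedAdam / DeepSpeedCPUAdam (assumed to handle param groups the same way).
import torch

layer = torch.nn.Linear(4, 4)
optim_groups = [
    {"params": [layer.weight], "weight_decay": 0.0, "my_lr_scale": 1.0},
    {"params": [layer.bias], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
opt = torch.optim.Adam(optim_groups, lr=1e-4)
for group in opt.param_groups:
    group["lr"] = 1e-4 * group["my_lr_scale"]  # mirrors the per-group scaling in the callback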
@@ -326,25 +348,13 @@ class RWKV(pl.LightningModule):
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
if self.trainer.global_rank == 0:
t_now = time.time_ns()
try:
t_cost = (t_now - self.trainer.my_time_ns) / 1e9
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
except:
pass
self.trainer.my_time_ns = t_now
self.trainer.my_loss = loss.item()
self.trainer.my_loss_sum += self.trainer.my_loss
self.trainer.my_loss_count += 1
self.trainer.my_epoch_loss = self.trainer.my_loss_sum / self.trainer.my_loss_count
self.log("lr", self.trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", self.trainer.my_epoch_loss, prog_bar=True, on_step=True)
return L2Wrap.apply(loss, logits)
def training_step_end(self, batch_parts):
all = self.all_gather(batch_parts)
if self.trainer.is_global_zero:
self.trainer.my_loss_all = all
def generate_init_weight(self):
print(
f"""

src/utils.py
@@ -0,0 +1,50 @@
import random
def MaybeIsPrime(number):
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
return True
else:
return False
def FermatPrimalityTest(number):
if number > 1:
for time in range(3):
randomNumber = random.randint(2, number) - 1
if pow(randomNumber, number - 1, number) != 1:
return False
return True
else:
return False
def MillerRabinPrimalityTest(number):
if number == 2:
return True
elif number == 1 or number % 2 == 0:
return False
oddPartOfNumber = number - 1
timesTwoDividNumber = 0
while oddPartOfNumber % 2 == 0:
oddPartOfNumber = oddPartOfNumber // 2
timesTwoDividNumber = timesTwoDividNumber + 1
for time in range(3):
while True:
randomNumber = random.randint(2, number) - 1
if randomNumber != 0 and randomNumber != 1:
break
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
iterationNumber = 1
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
iterationNumber = iterationNumber + 1
if randomNumberWithPower != (number - 1):
return False
return True
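# --- Usage sketch: MaybeIsPrime combines a 3-round Fermat test with a 3-round Miller-Rabin test,
# so it is probabilistic; a false positive is theoretically possible but irrelevant here, since it
# only sanity-checks a hard-coded constant.
assert MaybeIsPrime(324331313)      # the magic_prime used by MyDataset's Pile mode
assert not MaybeIsPrime(324331312)  # even, so MillerRabinPrimalityTest rejects it outright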

train.py
@@ -4,10 +4,11 @@
if __name__ == "__main__":
print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
import os, warnings, math, datetime, sys
import os, warnings, math, datetime, sys, time
import numpy as np
from argparse import ArgumentParser
import torch
from torch.utils.data import DataLoader
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import Trainer
@@ -83,10 +84,12 @@ if __name__ == "__main__":
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode
args = parser.parse_args()
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
@@ -95,6 +98,12 @@ if __name__ == "__main__":
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
if args.my_pile_mode > 0:
args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
if args.my_pile_mode == 2:
assert args.lr_final == args.lr_init
samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
@@ -138,9 +147,9 @@ if __name__ == "__main__":
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT"] = "1"
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT"] = "0"
os.environ["RWKV_JIT_ON"] = "0"
import torch
@@ -170,9 +179,37 @@ if __name__ == "__main__":
args = self.args
g_step = trainer.global_step
# logging
if trainer.global_rank == 0:
if g_step == 0:
# LR schedule
w_step = args.warmup_steps
if g_step < w_step:
lr = args.lr_init * (g_step / w_step)
else:
if args.lr_final == args.lr_init:
lr = args.lr_init
else:
progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
for param_group in trainer.optimizers[0].param_groups:
if self.args.my_pile_mode == 0:
param_group["lr"] = lr * param_group["my_lr_scale"]
elif self.args.my_pile_mode == 2:
if param_group["my_lr_scale"] > 1:
param_group["lr"] = lr * 5
else:
param_group["lr"] = lr
# print(param_group["lr"], param_group["my_lr_scale"])
trainer.my_lr = lr
# rank_zero_info(f"{g_step} {lr}")
if g_step == 0:
if trainer.is_global_zero: # logging
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
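# --- Sketch of the schedule above as a pure function: linear warmup to lr_init over w_step steps,
# then exponential interpolation from lr_init down to lr_final (or linear interpolation when either
# endpoint is 0), with progress clamped to [0, 1]. total_steps corresponds to
# args.epoch_count * args.epoch_steps.
import math

def lr_at(step, lr_init, lr_final, w_step, total_steps):
    if step < w_step:
        return lr_init * (step / w_step)
    if lr_final == lr_init:
        return lr_init
    progress = min(1, max(0, (step - w_step) / (total_steps - w_step - 1)))
    if lr_final == 0 or lr_init == 0:
        return lr_init + (lr_final - lr_init) * progress  # linear decay
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)  # exp decay

# e.g. lr_at(0, 6e-4, 1e-5, 50, 10000) == 0.0, and lr_at(10000, 6e-4, 1e-5, 50, 10000) ~ 1e-5 (lr_final)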
@@ -196,39 +233,42 @@ if __name__ == "__main__":
)
trainer.my_wandb = wandb
# LR schedule
w_step = args.warmup_steps
if g_step < w_step:
lr = args.lr_init * (g_step / w_step)
else:
progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
for param_group in trainer.optimizers[0].param_groups:
param_group["lr"] = lr
trainer.my_lr = lr
# rank_zero_info(f"{g_step} {lr}")
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
args = self.args
# logging
if trainer.global_rank == 0:
if trainer.is_global_zero: # logging
t_now = time.time_ns()
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
except:
pass
trainer.my_time_ns = t_now
trainer.my_loss = trainer.my_loss_all.float().mean().item()
trainer.my_loss_sum += trainer.my_loss
trainer.my_loss_count += 1
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
if len(args.wandb) > 0:
trainer.my_wandb.log(
{"loss": trainer.my_loss, "lr": trainer.my_lr},
step=trainer.global_step,
)
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "MyDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
def on_train_epoch_end(self, trainer, pl_module):
args = self.args
if trainer.global_rank == 0:
if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.epoch_count - 1:
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
torch.save(
pl_module.state_dict(),
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
@@ -251,7 +291,6 @@ if __name__ == "__main__":
########################################################################################################
from torch.utils.data import DataLoader
from src.dataset import MyDataset
from src.model import RWKV
@@ -261,8 +300,8 @@ if __name__ == "__main__":
model = RWKV(args)
if len(args.load_model) == 0:
args.load_model = f"{args.proj_dir}/rwkv-init.pth" # init weights to tmp file
generate_init_weight(model, args.load_model)
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, args.load_model) # save initial weights to tmp file
print(f"########## Loading {args.load_model}... ##########")
load_dict = torch.load(args.load_model, map_location="cpu")
@@ -273,5 +312,7 @@ if __name__ == "__main__":
callbacks=[train_callback(args)],
)
train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=1)
trainer.fit(model, train_loader)
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(model, data_loader)
