clean code

main
BlinkDL 3 years ago
parent 2388bca7c3
commit d1674732ed

@@ -26,7 +26,7 @@ class MyDataset(Dataset):
assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
assert self.samples_per_epoch == 40320
print("########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
print(f"########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
self.magic_prime = 324331313
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(self.magic_prime)
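
Aside (a sketch, not part of this commit): the hunk above asserts that magic_prime is prime and computes dataset_slot = data_size // ctx_len, but the sampler that consumes them is outside this diff. The sketch below only illustrates why primality is useful here: a multiplier coprime to the slot count lets a simple modular map visit every ctx_len window exactly once per cycle. The names (sample_offset, the toy numbers) are illustrative and not the repository's indexing scheme.

# Sketch only: full-coverage pseudo-random slot order from a coprime multiplier.
def sample_offset(step, magic_prime, dataset_slot, ctx_len):
    # step -> (step * magic_prime) % dataset_slot permutes 0..dataset_slot-1
    # whenever gcd(magic_prime, dataset_slot) == 1, which a large prime gives
    # unless dataset_slot happens to be a multiple of it.
    return ((step * magic_prime) % dataset_slot) * ctx_len

print([sample_offset(s, 7, 10, 1) for s in range(10)])  # every slot 0..9 appears exactly once
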

@@ -12,7 +12,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
def __nop(ob):
    return ob
@@ -44,7 +44,7 @@ class WKV(torch.autograd.Function):
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
assert B * C % min(C, 32) == 0
if "32" in os.environ["RWKV_FLOAT_MODE"]:
w = -torch.exp(w.contiguous())
u = u.contiguous()
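
Both WKV asserts are relaxed from min(C, 1024) to min(C, 32), matching a narrower thread-block width for the CUDA kernel: B * C elements must tile evenly into blocks of that many threads. The launcher itself is not in this diff, so the snippet below only restates the divisibility arithmetic under that assumption (function and variable names are illustrative).

# Sketch of the arithmetic behind "assert B * C % min(C, 32) == 0" (assumed launch width).
def wkv_launch_shape(B, C):
    threads_per_block = min(C, 32)
    assert (B * C) % threads_per_block == 0, "B*C must split evenly into thread blocks"
    num_blocks = (B * C) // threads_per_block
    return num_blocks, threads_per_block

print(wkv_launch_shape(B=12, C=768))  # (288, 32): each block covers 32 channels
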
@@ -71,7 +71,7 @@ class WKV(torch.autograd.Function):
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
assert B * C % min(C, 32) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device="cuda").contiguous()
gu = torch.zeros((B, C), device="cuda").contiguous()
@@ -306,6 +306,7 @@ class RWKV(pl.LightningModule):
if self.deepspeed_offload:
    return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
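
configure_optimizers picks DeepSpeedCPUAdam when optimizer state is offloaded to the CPU and FusedAdam otherwise; the body of the deepspeed_offload property is cut off by the diff context. A hypothetical standalone check with the same intent, written against the DeepSpeedStrategy imported above (the function name and config keys are assumptions, not the repository's code):

from pytorch_lightning.strategies import DeepSpeedStrategy

def uses_deepspeed_offload(trainer) -> bool:
    # Assumption: inspect the ZeRO section of the active DeepSpeed config for
    # optimizer/parameter offload; returns False for non-DeepSpeed strategies.
    strategy = trainer.strategy
    if isinstance(strategy, DeepSpeedStrategy):
        cfg = strategy.config.get("zero_optimization", {})
        return "offload_optimizer" in cfg or "offload_param" in cfg
    return False
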

src/trainer.py (new file)
@@ -0,0 +1,130 @@
import os, math, time, datetime
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.callbacks import TQDMProgressBar
class train_callback(pl.Callback):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        args = self.args
        g_step = trainer.global_step

        # LR schedule
        w_step = args.warmup_steps
        if g_step < w_step:
            lr = args.lr_init * (g_step / w_step)
        else:
            if args.lr_final == args.lr_init:
                lr = args.lr_init
            else:
                progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
                progress = min(1, max(0, progress))
                if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                    lr = args.lr_init + (args.lr_final - args.lr_init) * progress
                else:  # exp decay
                    lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))

        for param_group in trainer.optimizers[0].param_groups:
            if args.layerwise_lr > 0:
                if self.args.my_pile_mode == 0:
                    param_group["lr"] = lr * param_group["my_lr_scale"]
                elif self.args.my_pile_mode == 2:
                    if param_group["my_lr_scale"] > 1:
                        param_group["lr"] = lr * 5
                    else:
                        param_group["lr"] = lr
                # print(param_group["lr"], param_group["my_lr_scale"])
            else:
                param_group["lr"] = lr

        trainer.my_lr = lr
        # rank_zero_info(f"{g_step} {lr}")

        if g_step == 0:
            if trainer.is_global_zero:  # logging
                trainer.my_loss_sum = 0
                trainer.my_loss_count = 0
                trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
                trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
                try:
                    print(f"\n{trainer.strategy.config}\n")
                    trainer.my_log.write(f"{trainer.strategy.config}\n")
                except:
                    pass
                trainer.my_log.flush()
                if len(args.wandb) > 0:
                    print("Login to wandb...")
                    import wandb
                    model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
                    wandb.init(
                        project=args.wandb,
                        name=model_name + "-" + args.my_timestamp,
                        config=args,
                        save_code=False,
                    )
                    trainer.my_wandb = wandb

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        args = self.args
        if trainer.is_global_zero:  # logging
            t_now = time.time_ns()
            try:
                t_cost = (t_now - trainer.my_time_ns) / 1e9
                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
                self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
            except:
                pass
            trainer.my_time_ns = t_now
            trainer.my_loss = trainer.my_loss_all.float().mean().item()
            trainer.my_loss_sum += trainer.my_loss
            trainer.my_loss_count += 1
            trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
            self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
            self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
            if len(args.wandb) > 0:
                trainer.my_wandb.log(
                    {"loss": trainer.my_loss, "lr": trainer.my_lr},
                    step=trainer.global_step,
                )

    def on_train_epoch_start(self, trainer, pl_module):
        args = self.args
        dataset = trainer.train_dataloader.dataset.datasets
        assert "MyDataset" in str(dataset)
        dataset.global_rank = trainer.global_rank
        dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
        dataset.world_size = trainer.world_size

    def on_train_epoch_end(self, trainer, pl_module):
        args = self.args
        if trainer.is_global_zero:  # logging & save state_dict
            if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
                torch.save(
                    pl_module.state_dict(),
                    f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                )
            trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
            trainer.my_log.flush()
            trainer.my_loss_sum = 0
            trainer.my_loss_count = 0


@rank_zero_only
def generate_init_weight(model, temp_name):
    try:
        os.remove(temp_name)
    except:
        pass
    mm = model.generate_init_weight()
    print(f"Saving to {temp_name}...")
    torch.save(mm, temp_name)
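
on_train_batch_start computes the learning rate inline: linear warmup for warmup_steps, then either linear decay (when one endpoint is 0) or exponential decay from lr_init to lr_final over epoch_count * epoch_steps. Pulling just that arithmetic into a standalone helper (same math as above; the function and its argument names are mine) makes the schedule easy to sanity-check:

import math

def lr_at(step, lr_init, lr_final, warmup_steps, total_steps):
    # Mirrors the schedule in train_callback.on_train_batch_start above.
    if step < warmup_steps:
        return lr_init * (step / warmup_steps)
    if lr_final == lr_init:
        return lr_init
    progress = (step - warmup_steps) / (total_steps - warmup_steps - 1)
    progress = min(1, max(0, progress))
    if lr_final == 0 or lr_init == 0:  # linear decay
        return lr_init + (lr_final - lr_init) * progress
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)  # exp decay

# e.g. 8e-4 -> 1e-5 over 40320 steps with 100 warmup steps
for s in (0, 50, 100, 20000, 40319):
    print(s, f"{lr_at(s, 8e-4, 1e-5, 100, 40320):.2e}")
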

@@ -14,8 +14,6 @@ if __name__ == "__main__":
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning import Callback
# print("WARNING: THIS IS ONLY FOR DEBUG")
# seed_everything(42)
@@ -95,7 +93,7 @@ if __name__ == "__main__":
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 500 might be faster (but more VRAM)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
args = parser.parse_args()
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
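
--ds_bucket_mb is given in megabytes, while DeepSpeed itself expects bucket sizes in bytes. The strategy construction is not part of this hunk, so the snippet below is only a plausible sketch of how such a flag is commonly forwarded to Lightning's DeepSpeedStrategy (the stage=2 choice and helper name are assumptions):

from pytorch_lightning.strategies import DeepSpeedStrategy

def make_strategy(args):
    # Sketch: convert MB -> bytes and hand both bucket sizes to DeepSpeed.
    bucket_bytes = args.ds_bucket_mb * 1000 * 1000
    return DeepSpeedStrategy(
        stage=2,                             # assumption; the real stage comes from args.strategy
        allgather_bucket_size=bucket_bytes,  # bigger buckets: fewer, larger all-gathers
        reduce_bucket_size=bucket_bytes,     # bigger buckets: more VRAM during gradient reduce
    )
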
@@ -161,8 +159,6 @@ if __name__ == "__main__":
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
import torch
torch.backends.cudnn.benchmark = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
@@ -179,131 +175,8 @@ if __name__ == "__main__":
args.precision = "bf16"
########################################################################################################
class train_callback(pl.Callback):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        args = self.args
        g_step = trainer.global_step

        # LR schedule
        w_step = args.warmup_steps
        if g_step < w_step:
            lr = args.lr_init * (g_step / w_step)
        else:
            if args.lr_final == args.lr_init:
                lr = args.lr_init
            else:
                progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
                progress = min(1, max(0, progress))
                if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                    lr = args.lr_init + (args.lr_final - args.lr_init) * progress
                else:  # exp decay
                    lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))

        for param_group in trainer.optimizers[0].param_groups:
            if args.layerwise_lr > 0:
                if self.args.my_pile_mode == 0:
                    param_group["lr"] = lr * param_group["my_lr_scale"]
                elif self.args.my_pile_mode == 2:
                    if param_group["my_lr_scale"] > 1:
                        param_group["lr"] = lr * 5
                    else:
                        param_group["lr"] = lr
                # print(param_group["lr"], param_group["my_lr_scale"])
            else:
                param_group["lr"] = lr

        trainer.my_lr = lr
        # rank_zero_info(f"{g_step} {lr}")

        if g_step == 0:
            if trainer.is_global_zero:  # logging
                trainer.my_loss_sum = 0
                trainer.my_loss_count = 0
                trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
                trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
                try:
                    print(f"\n{trainer.strategy.config}\n")
                    trainer.my_log.write(f"{trainer.strategy.config}\n")
                except:
                    pass
                trainer.my_log.flush()
                if len(args.wandb) > 0:
                    print("Login to wandb...")
                    import wandb
                    model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
                    wandb.init(
                        project=args.wandb,
                        name=model_name + "-" + args.my_timestamp,
                        config=args,
                        save_code=False,
                    )
                    trainer.my_wandb = wandb

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        args = self.args
        if trainer.is_global_zero:  # logging
            t_now = time.time_ns()
            try:
                t_cost = (t_now - trainer.my_time_ns) / 1e9
                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
                self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
            except:
                pass
            trainer.my_time_ns = t_now
            trainer.my_loss = trainer.my_loss_all.float().mean().item()
            trainer.my_loss_sum += trainer.my_loss
            trainer.my_loss_count += 1
            trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
            self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
            self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
            if len(args.wandb) > 0:
                trainer.my_wandb.log(
                    {"loss": trainer.my_loss, "lr": trainer.my_lr},
                    step=trainer.global_step,
                )

    def on_train_epoch_start(self, trainer, pl_module):
        args = self.args
        dataset = trainer.train_dataloader.dataset.datasets
        assert "MyDataset" in str(dataset)
        dataset.global_rank = trainer.global_rank
        dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
        dataset.world_size = trainer.world_size

    def on_train_epoch_end(self, trainer, pl_module):
        args = self.args
        if trainer.is_global_zero:  # logging & save state_dict
            if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
                torch.save(
                    pl_module.state_dict(),
                    f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                )
            trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
            trainer.my_log.flush()
            trainer.my_loss_sum = 0
            trainer.my_loss_count = 0


@rank_zero_only
def generate_init_weight(model, temp_name):
    try:
        os.remove(temp_name)
    except:
        pass
    mm = model.generate_init_weight()
    print(f"Saving to {temp_name}...")
    torch.save(mm, temp_name)
########################################################################################################
from src.trainer import train_callback, generate_init_weight
from src.dataset import MyDataset
from src.model import RWKV
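
With the callback and helpers split out into src/trainer.py, train.py only needs to import and wire them up. The exact wiring is outside this diff; the sketch below shows one plausible arrangement (argument plumbing simplified, initial-weight file name assumed). Note that on_train_batch_end reads trainer.my_loss_all, which the RWKV LightningModule is expected to populate during its training step.

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from src.trainer import train_callback, generate_init_weight
from src.dataset import MyDataset
from src.model import RWKV

def run(args):
    # Sketch only: not the repository's train.py.
    model = RWKV(args)
    if args.epoch_begin == 0:
        generate_init_weight(model, f"{args.proj_dir}/rwkv-init.pth")  # rank-0 only; path assumed

    trainer = pl.Trainer(
        devices=args.devices,
        precision=args.precision,
        strategy=args.strategy,
        callbacks=[train_callback(args)],  # LR schedule, logging, checkpoint saving
    )
    data = DataLoader(MyDataset(args), batch_size=args.micro_bsz, shuffle=False, pin_memory=True)
    trainer.fit(model, data)
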
