code for pile training

main
BlinkDL 3 years ago
parent d1674732ed
commit 99a3dff414

.gitignore (vendored)

@@ -5,6 +5,10 @@
 *.xlsx
 *.xls
 wandb/
+data/
+vocab.json
+*.sh
+*log/
 # Byte-compiled / optimized / DLL files
 __pycache__/

src/dataset.py

@@ -22,11 +22,11 @@ class MyDataset(Dataset):
 self.data_size = len(self.data._bin_buffer) // 2
 print(f"Data has {self.data_size} tokens.")
-if args.my_pile_mode > 0:
+if args.my_pile_stage > 0:
     assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
     self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
     assert self.samples_per_epoch == 40320
-    print(f"########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
+    print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
     self.magic_prime = 324331313
     dataset_slot = self.data_size // args.ctx_len
     assert MaybeIsPrime(self.magic_prime)

@@ -46,7 +46,7 @@ class MyDataset(Dataset):
     aa = (i) % 10000
     bb = (i * i) % 10000
     cc = aa + bb
-    self.data += f'.{aa}+{bb}={cc}.'
+    self.data += f".{aa}+{bb}={cc}."
 else:
     self.data = open(args.data_file, "r", encoding=args.data_type).read()
     print("Building token list...")

@@ -84,7 +84,7 @@ class MyDataset(Dataset):
 ctx_len = args.ctx_len
 req_len = ctx_len + 1
-if args.my_pile_mode > 0:
+if args.my_pile_stage > 0:
     ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
     factor = (math.sqrt(5) - 1) / 2
     factor = int(self.magic_prime * factor)

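Note (reviewer sketch, not part of the commit): the dataset hunks above pin magic_prime = 324331313, a value asserted prime and close to data_size // ctx_len, and derive a stride from the golden ratio. The point is that stepping through a prime-sized slot space with a fixed multiplier visits the Pile in a scrambled but fully deterministic order, with no shuffle buffer. A minimal sketch of that idea, assuming only that the stride is coprime to the prime modulus; the repo's exact index formula is not visible in these hunks:

import math

def pseudo_random_slots(magic_prime: int, num_samples: int):
    """Yield a deterministic, non-repeating sequence of slot indices.

    Because `factor` is nonzero, smaller than the prime modulus, and therefore
    coprime to it, (factor * ii) % magic_prime takes distinct values for up to
    magic_prime consecutive ii, so no slot repeats within one pass. This only
    illustrates the idea behind the magic_prime / golden-ratio constants above.
    """
    factor = int(magic_prime * (math.sqrt(5) - 1) / 2)  # same constant as in the hunk
    for ii in range(1, num_samples + 1):
        yield (factor * ii) % magic_prime

# e.g. the first few slots for the Pile setup in the hunk:
# list(pseudo_random_slots(324331313, 5))
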
src/model.py

@@ -12,8 +12,10 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam

 def __nop(ob):
     return ob

@@ -278,7 +280,7 @@ class RWKV(pl.LightningModule):
 lr_2x = set()
 lr_3x = set()
 for n, p in self.named_parameters():
-    if ("time_mix" in n) and (self.args.my_pile_mode == 2):
+    if ("time_mix" in n) and (self.args.my_pile_stage == 2):
         lr_2x.add(n)
     elif "time_decay" in n:
         lr_2x.add(n)

@@ -382,7 +384,7 @@ class RWKV(pl.LightningModule):
     m[n] = p
 else:
     if n == "emb.weight":
-        scale = -25 * self.args.lr_init
+        scale = -1 * self.args.lr_init
     else:
         if shape[0] > shape[1]:
             gain = math.sqrt(shape[0] / shape[1])

@@ -406,7 +408,7 @@ class RWKV(pl.LightningModule):
 if scale == 0:
     nn.init.zeros_(m[n])
 elif scale < 0:
-    nn.init.normal_(m[n], mean=0.0, std=-scale)
+    nn.init.uniform_(m[n], a=scale, b=-scale)
 else:
     nn.init.orthogonal_(m[n], gain=gain * scale)

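Note (reviewer sketch, not part of the commit): the two initializer hunks above change emb.weight from scale = -25 * lr_init to -1 * lr_init and replace the normal init with a symmetric uniform init whenever scale is negative. A minimal, self-contained sketch of the resulting rule, reusing the diff's convention; the lr_init value below is a placeholder, not a value from the repo:

import torch
import torch.nn as nn

def init_param(p: torch.Tensor, scale: float, gain: float = 1.0) -> None:
    # scale == 0 -> zeros; scale < 0 -> uniform in [scale, -scale] (new behaviour);
    # scale > 0  -> orthogonal with gain * scale (unchanged by this commit).
    if scale == 0:
        nn.init.zeros_(p)
    elif scale < 0:
        nn.init.uniform_(p, a=scale, b=-scale)
    else:
        nn.init.orthogonal_(p, gain=gain * scale)

# e.g. an embedding initialised like emb.weight after this commit,
# assuming lr_init = 6e-4 as a placeholder value:
emb = torch.empty(50277, 768)
init_param(emb, scale=-1 * 6e-4)
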
src/trainer.py

@@ -2,9 +2,8 @@ import os, math, time, datetime
 import torch
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
-from pytorch_lightning import seed_everything
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-from pytorch_lightning.callbacks import TQDMProgressBar

 class train_callback(pl.Callback):
     def __init__(self, args):

@@ -33,9 +32,9 @@ class train_callback(pl.Callback):
 for param_group in trainer.optimizers[0].param_groups:
     if args.layerwise_lr > 0:
-        if self.args.my_pile_mode == 0:
+        if self.args.my_pile_stage != 2:
             param_group["lr"] = lr * param_group["my_lr_scale"]
-        elif self.args.my_pile_mode == 2:
+        else:
             if param_group["my_lr_scale"] > 1:
                 param_group["lr"] = lr * 5
             else:

@@ -63,10 +62,10 @@ class train_callback(pl.Callback):
 print("Login to wandb...")
 import wandb
-model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
+model_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
 wandb.init(
     project=args.wandb,
-    name=model_name + "-" + args.my_timestamp,
+    name=model_name + " " + args.my_timestamp,
     config=args,
     save_code=False,
 )

@@ -76,10 +75,12 @@ class train_callback(pl.Callback):
 args = self.args
 if trainer.is_global_zero: # logging
     t_now = time.time_ns()
+    token_per_step = args.ctx_len * float(args.devices) * args.micro_bsz
+    real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
     try:
         t_cost = (t_now - trainer.my_time_ns) / 1e9
         self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-        self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
+        self.log("Kt/s", token_per_step / t_cost / 1000, prog_bar=True, on_step=True)
     except:
         pass
     trainer.my_time_ns = t_now

@@ -89,11 +90,12 @@ class train_callback(pl.Callback):
 trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
 self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
 self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
+# self.log("s", real_step, prog_bar=True, on_step=True)
 if len(args.wandb) > 0:
     trainer.my_wandb.log(
-        {"loss": trainer.my_loss, "lr": trainer.my_lr},
-        step=trainer.global_step,
+        {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9},
+        step=int(real_step),
     )

 def on_train_epoch_start(self, trainer, pl_module):

@@ -120,11 +122,19 @@
 @rank_zero_only
-def generate_init_weight(model, temp_name):
+def generate_init_weight(model, init_weight_name):
-    try:
-        os.remove(temp_name)
-    except:
-        pass
     mm = model.generate_init_weight()
-    print(f"Saving to {temp_name}...")
-    torch.save(mm, temp_name)
+
+    if model.args.my_pile_stage == 1:
+        print(f"Combine weights from {model.args.load_model}...")
+        load_dict = torch.load(model.args.load_model, map_location="cpu")
+        for k in load_dict:
+            assert k in mm
+            mm[k] = load_dict[k].reshape(mm[k].shape)
+
+    print(f"Save to {init_weight_name}...")
+    torch.save(mm, init_weight_name)
+
+    if model.args.my_pile_stage == 1:
+        print("Done. Now go for stage 2.")
+        exit(0)

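Note (reviewer sketch, not part of the commit): the logging hunks above switch the throughput metric to Kt/s and report a cumulative Gtokens figure from real_step, which adds epoch_begin * epoch_steps back onto global_step so the counters keep increasing across restarts. The arithmetic, spelled out with the diff's names; devices, micro_bsz, and the timing below are placeholder values, and epoch_steps = 1260 follows from the asserted 40320 samples per epoch at devices=8, micro_bsz=4:

# token_per_step = ctx_len * devices * micro_bsz
# e.g. ctx_len=1024, devices=8, micro_bsz=4 -> 32768 tokens per optimizer step
token_per_step = 1024 * 8.0 * 4

# real_step folds the resumed epoch offset back in: epoch_begin * epoch_steps + global_step
real_step = 3 * 1260 + 500           # placeholder epoch_begin=3, epoch_steps=1260, global_step=500

t_cost = 0.25                        # seconds per step (placeholder)
kt_per_s = token_per_step / t_cost / 1000   # "Kt/s" shown in the progress bar
gtokens = real_step * token_per_step / 1e9  # "Gtokens" sent to wandb
print(kt_per_s, gtokens)
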
train.py

@@ -3,7 +3,7 @@
 ########################################################################################################

 if __name__ == "__main__":
-    print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
+    print("\n!!! work in progress !!!\n")
     import os, warnings, math, datetime, sys, time
     import numpy as np
     from argparse import ArgumentParser

@@ -23,7 +23,7 @@ if __name__ == "__main__":
 warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
 ########################################################################################################
 #
 # example: train a simple L12-D768 RWKV on dummy data
 #
 # python train.py --load_model "" --wandb "" --proj_dir "out" \

@@ -91,7 +91,7 @@ if __name__ == "__main__":
 parser.add_argument("--adam_eps", default=1e-8, type=float)
 parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
-parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode
+parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
 parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
 parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough

@@ -107,11 +107,32 @@ if __name__ == "__main__":
 args.max_epochs = -1 # continue forever
 args.betas = (args.beta1, args.beta2)

-if args.my_pile_mode > 0:
+if args.my_pile_stage > 0:
     args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
     assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
-    if args.my_pile_mode == 2:
+    if args.my_pile_stage == 2:
         assert args.lr_final == args.lr_init
+    if args.my_pile_stage >= 2: # find latest saved model
+        pths = os.listdir(args.proj_dir)
+        max_p = -1
+        for p in pths:
+            if p.startswith("rwkv") and p.endswith(".pth"):
+                p = ((p.split("-"))[1].split("."))[0]
+                if p == "init":
+                    p = -1
+                else:
+                    p = int(p)
+                if p > max_p:
+                    max_p = p
+        if max_p == -1:
+            args.load_model = f"{args.proj_dir}/rwkv-init.pth"
+        else:
+            args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
+            if args.my_pile_stage == 2:
+                args.warmup_steps = 10
+            else:
+                args.warmup_steps = 50
+            args.epoch_begin = max_p + 1

 samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
 tokens_per_epoch = samples_per_epoch * args.ctx_len

@@ -175,7 +196,7 @@ if __name__ == "__main__":
 args.precision = "bf16"

 ########################################################################################################

 from src.trainer import train_callback, generate_init_weight
 from src.dataset import MyDataset
 from src.model import RWKV

@@ -185,9 +206,10 @@ if __name__ == "__main__":
 model = RWKV(args)

-if len(args.load_model) == 0:
-    args.load_model = f"{args.proj_dir}/rwkv-init.pth"
-    generate_init_weight(model, args.load_model) # save initial weights to tmp file
+if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
+    init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
+    generate_init_weight(model, init_weight_name) # save initial weights
+    args.load_model = init_weight_name

 print(f"########## Loading {args.load_model}... ##########")
 load_dict = torch.load(args.load_model, map_location="cpu")

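Note (reviewer sketch, not part of the commit): with my_pile_stage >= 2, the block added to train.py scans proj_dir for rwkv-init.pth and rwkv-<N>.pth, resumes from the largest N, shortens warmup (10 steps for stage 2, otherwise 50), and sets epoch_begin = N + 1. The same scan as a standalone helper for clarity; the function name is illustrative, not from the repo:

import os

def find_resume_checkpoint(proj_dir: str):
    """Return (load_model_path, epoch_begin) following the logic added in this commit.

    Checkpoints are expected to be named rwkv-init.pth or rwkv-<epoch>.pth.
    """
    max_p = -1
    for name in os.listdir(proj_dir):
        if name.startswith("rwkv") and name.endswith(".pth"):
            tag = name.split("-")[1].split(".")[0]
            if tag != "init" and int(tag) > max_p:
                max_p = int(tag)
    path = f"{proj_dir}/rwkv-init.pth" if max_p == -1 else f"{proj_dir}/rwkv-{max_p}.pth"
    return path, max_p + 1  # epoch_begin; 0 when only rwkv-init.pth exists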