code for pile training

main
BlinkDL 3 years ago
parent d1674732ed
commit 99a3dff414

4
.gitignore vendored

@ -5,6 +5,10 @@
*.xlsx
*.xls
wandb/
data/
vocab.json
*.sh
*log/
# Byte-compiled / optimized / DLL files
__pycache__/

@ -22,11 +22,11 @@ class MyDataset(Dataset):
self.data_size = len(self.data._bin_buffer) // 2
print(f"Data has {self.data_size} tokens.")
if args.my_pile_mode > 0:
if args.my_pile_stage > 0:
assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
assert self.samples_per_epoch == 40320
print(f"########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
self.magic_prime = 324331313
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(self.magic_prime)
@ -46,7 +46,7 @@ class MyDataset(Dataset):
aa = (i) % 10000
bb = (i * i) % 10000
cc = aa + bb
self.data += f'.{aa}+{bb}={cc}.'
self.data += f".{aa}+{bb}={cc}."
else:
self.data = open(args.data_file, "r", encoding=args.data_type).read()
print("Building token list...")
@ -84,7 +84,7 @@ class MyDataset(Dataset):
ctx_len = args.ctx_len
req_len = ctx_len + 1
if args.my_pile_mode > 0:
if args.my_pile_stage > 0:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
factor = (math.sqrt(5) - 1) / 2
factor = int(self.magic_prime * factor)

@ -12,8 +12,10 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
def __nop(ob):
return ob
@ -278,7 +280,7 @@ class RWKV(pl.LightningModule):
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if ("time_mix" in n) and (self.args.my_pile_mode == 2):
if ("time_mix" in n) and (self.args.my_pile_stage == 2):
lr_2x.add(n)
elif "time_decay" in n:
lr_2x.add(n)
@ -382,7 +384,7 @@ class RWKV(pl.LightningModule):
m[n] = p
else:
if n == "emb.weight":
scale = -25 * self.args.lr_init
scale = -1 * self.args.lr_init
else:
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
@ -406,7 +408,7 @@ class RWKV(pl.LightningModule):
if scale == 0:
nn.init.zeros_(m[n])
elif scale < 0:
nn.init.normal_(m[n], mean=0.0, std=-scale)
nn.init.uniform_(m[n], a=scale, b=-scale)
else:
nn.init.orthogonal_(m[n], gain=gain * scale)

@ -2,9 +2,8 @@ import os, math, time, datetime
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.callbacks import TQDMProgressBar
class train_callback(pl.Callback):
def __init__(self, args):
@ -33,9 +32,9 @@ class train_callback(pl.Callback):
for param_group in trainer.optimizers[0].param_groups:
if args.layerwise_lr > 0:
if self.args.my_pile_mode == 0:
if self.args.my_pile_stage != 2:
param_group["lr"] = lr * param_group["my_lr_scale"]
elif self.args.my_pile_mode == 2:
else:
if param_group["my_lr_scale"] > 1:
param_group["lr"] = lr * 5
else:
@ -63,10 +62,10 @@ class train_callback(pl.Callback):
print("Login to wandb...")
import wandb
model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
model_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
wandb.init(
project=args.wandb,
name=model_name + "-" + args.my_timestamp,
name=model_name + " " + args.my_timestamp,
config=args,
save_code=False,
)
@ -76,10 +75,12 @@ class train_callback(pl.Callback):
args = self.args
if trainer.is_global_zero: # logging
t_now = time.time_ns()
token_per_step = args.ctx_len * float(args.devices) * args.micro_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
self.log("Kt/s", token_per_step / t_cost / 1000, prog_bar=True, on_step=True)
except:
pass
trainer.my_time_ns = t_now
@ -89,11 +90,12 @@ class train_callback(pl.Callback):
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
# self.log("s", real_step, prog_bar=True, on_step=True)
if len(args.wandb) > 0:
trainer.my_wandb.log(
{"loss": trainer.my_loss, "lr": trainer.my_lr},
step=trainer.global_step,
{"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9},
step=int(real_step),
)
def on_train_epoch_start(self, trainer, pl_module):
@ -120,11 +122,19 @@ class train_callback(pl.Callback):
@rank_zero_only
def generate_init_weight(model, temp_name):
try:
os.remove(temp_name)
except:
pass
def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight()
print(f"Saving to {temp_name}...")
torch.save(mm, temp_name)
if model.args.my_pile_stage == 1:
print(f"Combine weights from {model.args.load_model}...")
load_dict = torch.load(model.args.load_model, map_location="cpu")
for k in load_dict:
assert k in mm
mm[k] = load_dict[k].reshape(mm[k].shape)
print(f"Save to {init_weight_name}...")
torch.save(mm, init_weight_name)
if model.args.my_pile_stage == 1:
print("Done. Now go for stage 2.")
exit(0)

@ -3,7 +3,7 @@
########################################################################################################
if __name__ == "__main__":
print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
print("\n!!! work in progress !!!\n")
import os, warnings, math, datetime, sys, time
import numpy as np
from argparse import ArgumentParser
@ -23,7 +23,7 @@ if __name__ == "__main__":
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
@ -91,7 +91,7 @@ if __name__ == "__main__":
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
@ -107,11 +107,32 @@ if __name__ == "__main__":
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
if args.my_pile_mode > 0:
if args.my_pile_stage > 0:
args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
if args.my_pile_mode == 2:
if args.my_pile_stage == 2:
assert args.lr_final == args.lr_init
if args.my_pile_stage >= 2: # find latest saved model
pths = os.listdir(args.proj_dir)
max_p = -1
for p in pths:
if p.startswith("rwkv") and p.endswith(".pth"):
p = ((p.split("-"))[1].split("."))[0]
if p == "init":
p = -1
else:
p = int(p)
if p > max_p:
max_p = p
if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
if args.my_pile_stage == 2:
args.warmup_steps = 10
else:
args.warmup_steps = 50
args.epoch_begin = max_p + 1
samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
@ -185,9 +206,10 @@ if __name__ == "__main__":
model = RWKV(args)
if len(args.load_model) == 0:
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, args.load_model) # save initial weights to tmp file
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name
print(f"########## Loading {args.load_model}... ##########")
load_dict = torch.load(args.load_model, map_location="cpu")

Loading…
Cancel
Save