|
|
|
@ -3,7 +3,7 @@
|
|
|
|
########################################################################################################
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
|
|
|
|
print("\n!!! work in progress !!!\n")
|
|
|
|
import os, warnings, math, datetime, sys, time
|
|
|
|
import os, warnings, math, datetime, sys, time
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
@ -23,7 +23,7 @@ if __name__ == "__main__":
|
|
|
|
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
|
|
|
|
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
|
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
#
|
|
|
|
# example: train a simple L12-D768 RWKV on dummy data
|
|
|
|
# example: train a simple L12-D768 RWKV on dummy data
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
|
|
|
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
|
|
|
@ -91,7 +91,7 @@ if __name__ == "__main__":
|
|
|
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
|
|
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
|
|
|
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
|
|
|
parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode
|
|
|
|
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
|
|
|
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
|
|
|
|
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
|
|
|
|
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
|
|
|
|
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
|
|
|
|
|
|
|
|
|
|
|
|
@ -107,11 +107,32 @@ if __name__ == "__main__":
|
|
|
|
args.max_epochs = -1 # continue forever
|
|
|
|
args.max_epochs = -1 # continue forever
|
|
|
|
args.betas = (args.beta1, args.beta2)
|
|
|
|
args.betas = (args.beta1, args.beta2)
|
|
|
|
|
|
|
|
|
|
|
|
if args.my_pile_mode > 0:
|
|
|
|
if args.my_pile_stage > 0:
|
|
|
|
args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
|
|
|
|
args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
|
|
|
|
assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
|
|
|
|
assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
|
|
|
|
if args.my_pile_mode == 2:
|
|
|
|
if args.my_pile_stage == 2:
|
|
|
|
assert args.lr_final == args.lr_init
|
|
|
|
assert args.lr_final == args.lr_init
|
|
|
|
|
|
|
|
if args.my_pile_stage >= 2: # find latest saved model
|
|
|
|
|
|
|
|
pths = os.listdir(args.proj_dir)
|
|
|
|
|
|
|
|
max_p = -1
|
|
|
|
|
|
|
|
for p in pths:
|
|
|
|
|
|
|
|
if p.startswith("rwkv") and p.endswith(".pth"):
|
|
|
|
|
|
|
|
p = ((p.split("-"))[1].split("."))[0]
|
|
|
|
|
|
|
|
if p == "init":
|
|
|
|
|
|
|
|
p = -1
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
p = int(p)
|
|
|
|
|
|
|
|
if p > max_p:
|
|
|
|
|
|
|
|
max_p = p
|
|
|
|
|
|
|
|
if max_p == -1:
|
|
|
|
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
|
|
|
|
|
|
|
if args.my_pile_stage == 2:
|
|
|
|
|
|
|
|
args.warmup_steps = 10
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
args.warmup_steps = 50
|
|
|
|
|
|
|
|
args.epoch_begin = max_p + 1
|
|
|
|
|
|
|
|
|
|
|
|
samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
|
|
|
|
samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
|
|
|
|
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
|
|
|
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
|
|
|
@ -175,7 +196,7 @@ if __name__ == "__main__":
|
|
|
|
args.precision = "bf16"
|
|
|
|
args.precision = "bf16"
|
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
from src.trainer import train_callback, generate_init_weight
|
|
|
|
from src.trainer import train_callback, generate_init_weight
|
|
|
|
from src.dataset import MyDataset
|
|
|
|
from src.dataset import MyDataset
|
|
|
|
from src.model import RWKV
|
|
|
|
from src.model import RWKV
|
|
|
|
@ -185,9 +206,10 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
model = RWKV(args)
|
|
|
|
model = RWKV(args)
|
|
|
|
|
|
|
|
|
|
|
|
if len(args.load_model) == 0:
|
|
|
|
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
|
|
|
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
|
|
|
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
|
|
|
generate_init_weight(model, args.load_model) # save initial weights to tmp file
|
|
|
|
generate_init_weight(model, init_weight_name) # save initial weights
|
|
|
|
|
|
|
|
args.load_model = init_weight_name
|
|
|
|
|
|
|
|
|
|
|
|
print(f"########## Loading {args.load_model}... ##########")
|
|
|
|
print(f"########## Loading {args.load_model}... ##########")
|
|
|
|
load_dict = torch.load(args.load_model, map_location="cpu")
|
|
|
|
load_dict = torch.load(args.load_model, map_location="cpu")
|
|
|
|
|