BlinkDL 3 years ago
parent 7b92a979d8
commit 1189c9e238

@@ -23,15 +23,14 @@ class MyDataset(Dataset):
             print(f"Data has {self.data_size} tokens.")

             if args.my_pile_stage > 0:
-                assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
+                assert self.data_size == 332115325534 and self.vocab_size == 50277
                 self.samples_per_epoch = args.epoch_steps * args.real_bsz
                 assert self.samples_per_epoch == 40320
                 print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
-                self.magic_prime = 324331313
                 dataset_slot = self.data_size // args.ctx_len
-                assert MaybeIsPrime(self.magic_prime)
-                assert self.magic_prime % 3 == 2
-                assert self.magic_prime / dataset_slot > 0.999999 and self.magic_prime / dataset_slot <= 1
+                assert MaybeIsPrime(args.magic_prime)
+                assert args.magic_prime % 3 == 2
+                assert args.magic_prime / dataset_slot > 0.999999 and args.magic_prime / dataset_slot <= 1
         elif args.data_type == "numpy":
             self.data = np.load(args.data_file).astype("int")
             self.vocab_size = args.vocab_size
@@ -87,8 +86,9 @@ class MyDataset(Dataset):
         if args.my_pile_stage > 0:
             ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
             factor = (math.sqrt(5) - 1) / 2
-            factor = int(self.magic_prime * factor)
-            i = ((factor * ii * ii * ii) % self.magic_prime) * ctx_len
+            factor = int(args.magic_prime * factor)
+            i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
+            i = i + args.my_pile_shift
             # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
         else:
             i = np.random.randint(0, self.data_size - req_len)
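Why the asserts above hold: the sampler maps sample index ii to data offset i = ((factor * ii^3) mod magic_prime) * ctx_len. For a prime p with p % 3 == 2 we have gcd(3, p - 1) = 1, so x -> x^3 mod p permutes 0..p-1, and multiplying by the nonzero constant factor keeps it a permutation; every ctx_len-sized slot is therefore visited exactly once per cycle of p samples, a deterministic shuffle that needs no stored permutation and resumes exactly after a restart. A minimal sketch checking this on a toy prime (illustrative, not part of the commit):

import math

p = 11  # toy prime with p % 3 == 2, standing in for args.magic_prime
factor = int(p * (math.sqrt(5) - 1) / 2)  # golden-ratio multiplier, as in the diff
positions = sorted((factor * ii * ii * ii) % p for ii in range(p))
assert positions == list(range(p))  # every slot is hit exactly once per cycle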

@@ -31,12 +31,12 @@ if os.environ["RWKV_JIT_ON"] == "1":
 # CUDA Kernel
 ########################################################################################################

-T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
+T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
 # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

 from torch.utils.cpp_extension import load

-wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])
+wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])

 class WKV(torch.autograd.Function):
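Context for the rename: the kernel is specialized at compile time via -DTmax, so a binary built for one context length cannot serve another, and torch.utils.cpp_extension.load() caches builds by name. Under the fixed name "wkv", a kernel compiled with Tmax=1024 could be silently reused after switching ctx_len; baking T_MAX into the name gives each context length its own cached build. A hedged usage sketch (the src.model import path is assumed from the repo layout):

import os

os.environ["RWKV_T_MAX"] = "2048"  # train.py now sets this from args.ctx_len
import src.model  # compiles (or reuses) an extension named "wkv_2048" built with -DTmax=2048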

@@ -14,17 +14,17 @@ class train_callback(pl.Callback):
         args = self.args
         # if args.cuda_cleanup > 0:
         # torch.cuda.empty_cache()
-        g_step = trainer.global_step
+        real_step = trainer.global_step + args.epoch_begin * args.epoch_steps

         # LR schedule
         w_step = args.warmup_steps
-        if g_step < w_step:
-            lr = args.lr_init * (g_step / w_step)
+        if trainer.global_step < w_step:
+            lr = args.lr_init * (trainer.global_step / w_step)
         else:
-            if args.lr_final == args.lr_init:
+            if args.lr_final == args.lr_init or args.epoch_count == 0:
                 lr = args.lr_init
             else:
-                progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
+                progress = (real_step - w_step + 1) / (args.epoch_count * args.epoch_steps - w_step)
                 progress = min(1, max(0, progress))
                 if args.lr_final == 0 or args.lr_init == 0: # linear decay
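Put together, the revised schedule warms up on the raw trainer.global_step, then decays along real_step, so a run resumed with --epoch_begin > 0 lands at the right point on the decay curve instead of restarting it. A self-contained sketch of the whole schedule (the exponential branch is reconstructed from lr_init/lr_final; the diff only shows the linear-decay condition):

import math

def lr_at(global_step, epoch_begin, epoch_steps, epoch_count,
          warmup_steps, lr_init, lr_final):
    real_step = global_step + epoch_begin * epoch_steps
    if global_step < warmup_steps:  # warmup uses the raw step, not real_step
        return lr_init * (global_step / warmup_steps)
    if lr_final == lr_init or epoch_count == 0:
        return lr_init
    progress = (real_step - warmup_steps + 1) / (epoch_count * epoch_steps - warmup_steps)
    progress = min(1, max(0, progress))
    if lr_final == 0 or lr_init == 0:  # linear decay
        return lr_init + (lr_final - lr_init) * progress
    # exponential decay: interpolate log-linearly between lr_init and lr_final
    return math.exp(math.log(lr_init) + progress * (math.log(lr_final) - math.log(lr_init)))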
@@ -40,9 +40,9 @@ class train_callback(pl.Callback):
                 param_group["lr"] = lr

         trainer.my_lr = lr
-        # rank_zero_info(f"{g_step} {lr}")
+        # rank_zero_info(f"{real_step} {lr}")

-        if g_step == 0:
+        if trainer.global_step == 0:
             if trainer.is_global_zero: # logging
                 trainer.my_loss_sum = 0
                 trainer.my_loss_count = 0

@@ -92,6 +92,7 @@ if __name__ == "__main__":
     parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
     parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
+    parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
     parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
     parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
@@ -108,8 +109,27 @@ if __name__ == "__main__":
     args.max_epochs = -1 # continue forever
     args.betas = (args.beta1, args.beta2)
     args.real_bsz = int(args.devices) * args.micro_bsz
+    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
+
+    if not os.path.exists(args.proj_dir):
+        os.makedirs(args.proj_dir)
+
     if args.my_pile_stage > 0:
+        if args.ctx_len == 1024:
+            args.magic_prime = 324331313
+        elif args.ctx_len == 2048:
+            args.magic_prime = 162165671
+        elif args.ctx_len == 4096:
+            args.magic_prime = 81082817
+        if args.my_pile_shift < 0:
+            if args.ctx_len == 1024:
+                args.my_pile_shift = 0
+            elif args.ctx_len == 2048:
+                args.my_pile_shift = 512
+            elif args.ctx_len == 4096:
+                args.my_pile_shift = 768
+
         args.epoch_count = 8043
         args.epoch_steps = 40320 // args.real_bsz
         assert args.epoch_steps * args.real_bsz == 40320
         if args.my_pile_stage == 2:
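The hardcoded primes are not arbitrary: each is a prime p with p % 3 == 2 lying just below data_size // ctx_len, which is exactly what MyDataset asserts above (MaybeIsPrime(p), p % 3 == 2, and p / dataset_slot > 0.999999). A hedged sketch of how primes like these can be found for a new ctx_len (helper name hypothetical; sympy's isprime stands in for the repo's MaybeIsPrime):

from sympy import isprime

def find_magic_prime(data_size, ctx_len):
    dataset_slot = data_size // ctx_len
    p = dataset_slot  # search downward for the largest qualifying prime
    while not (p % 3 == 2 and isprime(p)):
        p -= 1
    assert p / dataset_slot > 0.999999  # mirrors the dataset's coverage assert
    return p

# Pile 20b-tokenized: 332115325534 tokens at ctx_len 1024 -> 324331313
print(find_magic_prime(332115325534, 1024))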
@@ -164,9 +184,6 @@ if __name__ == "__main__":
     )
     rank_zero_info(str(vars(args)) + "\n")

-    if not os.path.exists(args.proj_dir):
-        os.makedirs(args.proj_dir)
-
     assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy"]

     if args.lr_final == 0 or args.lr_init == 0:
@@ -218,7 +235,7 @@ if __name__ == "__main__":
     try:
         load_dict = torch.load(args.load_model, map_location="cpu")
     except:
-        print(f'Bad checkpoint {args.load_model}')
+        print(f"Bad checkpoint {args.load_model}")
         if args.my_pile_stage >= 2: # try again using another checkpoint
             max_p = args.my_pile_prev_p
             if max_p == -1:
@@ -227,7 +244,7 @@ if __name__ == "__main__":
             else:
                 args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
             args.epoch_begin = max_p + 1
-            print(f'Trying {args.load_model}')
+            print(f"Trying {args.load_model}")
             load_dict = torch.load(args.load_model, map_location="cpu")

     model.load_state_dict(load_dict)
