diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 0e28642..4884315 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -18,9 +18,9 @@ class MyDataset(Dataset): if args.data_type == "binidx": self.data = MMapIndexedDataset(args.data_file) self.vocab_size = args.vocab_size - print("current vocab size =", self.vocab_size, "(make sure it's correct)") + print("Current vocab size =", self.vocab_size, "(make sure it's correct)") self.data_size = len(self.data._bin_buffer) // 2 - print(f"data has {self.data_size} tokens.") + print(f"Data has {self.data_size} tokens.") if args.my_pile_mode > 0: assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024 @@ -32,16 +32,24 @@ class MyDataset(Dataset): assert MaybeIsPrime(self.magic_prime) assert self.magic_prime % 3 == 2 assert self.magic_prime / dataset_slot > 0.999999 and self.magic_prime / dataset_slot <= 1 - elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size - print("current vocab size =", self.vocab_size, "(make sure it's correct)") + print("Current vocab size =", self.vocab_size, "(make sure it's correct)") self.data_size = len(self.data) - print(f"data has {self.data_size} tokens.") + print(f"Data has {self.data_size} tokens.") else: - self.data = open(args.data_file, "r", encoding=args.data_type).read() - print("building token list...", end=" ") + if args.data_type == "dummy": + print("Building dummy data...") + self.data = "" + for i in range(100000): + aa = (i) % 10000 + bb = (i * i) % 10000 + cc = aa + bb + self.data += f'.{aa}+{bb}={cc}.' + else: + self.data = open(args.data_file, "r", encoding=args.data_type).read() + print("Building token list...") unique = sorted(list(set(self.data))) self.vocab_size = len(unique) # print() @@ -56,7 +64,7 @@ class MyDataset(Dataset): with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) self.data_size = len(self.data) - print("data has %d tokens, %d unique." % (self.data_size, self.vocab_size)) + print("Data has %d tokens, %d vocab size." % (self.data_size, self.vocab_size)) self.stoi = {ch: i for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)} @@ -67,15 +75,16 @@ class MyDataset(Dataset): # # we are cheating: pick a random spot in dataset # + args = self.args rank = self.global_rank epoch = self.real_epoch world_size = self.world_size # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") - ctx_len = self.args.ctx_len + ctx_len = args.ctx_len req_len = ctx_len + 1 - if self.args.my_pile_mode > 0: + if args.my_pile_mode > 0: ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank factor = (math.sqrt(5) - 1) / 2 factor = int(self.magic_prime * factor) @@ -84,9 +93,9 @@ class MyDataset(Dataset): else: i = np.random.randint(0, self.data_size - req_len) - if "MMapIndexedDataset" in str(type(self.data)): + if args.data_type == "binidx": dix = self.data.get(idx=0, offset=i, length=req_len).astype(int) - elif "numpy" in str(type(self.data)): + elif args.data_type == "numpy": dix = self.data[i : i + req_len] else: dix = [self.stoi[s] for s in self.data[i : i + req_len]] diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index d29db67..39a6326 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -273,30 +273,35 @@ class RWKV(pl.LightningModule): self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) def configure_optimizers(self): - lr_1x = set() - lr_2x = set() - lr_3x = set() - for n, p in self.named_parameters(): - if ("time_mix" in n) and (self.args.my_pile_mode == 2): - lr_2x.add(n) - elif "time_decay" in n: - lr_2x.add(n) - elif "time_first" in n: - lr_3x.add(n) - else: - lr_1x.add(n) - lr_1x = sorted(list(lr_1x)) - lr_2x = sorted(list(lr_2x)) - lr_3x = sorted(list(lr_3x)) - # print('1x', lr_1x) - # print('2x', lr_2x) - # print('3x', lr_3x) - param_dict = {n: p for n, p in self.named_parameters()} - optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, - ] + if self.args.layerwise_lr > 0: + lr_1x = set() + lr_2x = set() + lr_3x = set() + for n, p in self.named_parameters(): + if ("time_mix" in n) and (self.args.my_pile_mode == 2): + lr_2x.add(n) + elif "time_decay" in n: + lr_2x.add(n) + elif "time_first" in n: + lr_3x.add(n) + else: + lr_1x.add(n) + lr_1x = sorted(list(lr_1x)) + lr_2x = sorted(list(lr_2x)) + lr_3x = sorted(list(lr_3x)) + # print('1x', lr_1x) + # print('2x', lr_2x) + # print('3x', lr_3x) + param_dict = {n: p for n, p in self.named_parameters()} + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] + else: + optim_groups = [ + {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + ] if self.deepspeed_offload: return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 33ae2d3..4716da0 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -25,8 +25,17 @@ if __name__ == "__main__": warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") ######################################################################################################## + + # example: train a simple L12-D768 RWKV on dummy data + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "" --data_type "dummy" --vocab_size 0 \ + # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \ + # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \ + # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 - # example: train a simple L6-D512 RWKV from scratch + # example: train a simple L6-D512 RWKV from scratch on enwik8 # # python train.py --load_model "" --wandb "" --proj_dir "out" \ # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \ @@ -56,8 +65,8 @@ if __name__ == "__main__": parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) - parser.add_argument("--load_model", default="", type=str) - parser.add_argument("--wandb", default="", type=str) # wandb project name + parser.add_argument("--load_model", default="", type=str) # full path, with .pth + parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb parser.add_argument("--proj_dir", default="out", type=str) parser.add_argument("--data_file", default="", type=str) @@ -65,26 +74,28 @@ if __name__ == "__main__": parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) parser.add_argument("--ctx_len", default=1024, type=int) - parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has xxx steps - parser.add_argument("--epoch_count", default=500, type=int) - parser.add_argument("--epoch_begin", default=0, type=int) - parser.add_argument("--epoch_save", default=5, type=int) + parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps + parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final + parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) parser.add_argument("--n_layer", default=6, type=int) parser.add_argument("--n_embd", default=512, type=int) - parser.add_argument("--pre_ffn", default=0, type=int) - parser.add_argument("--head_qk", default=0, type=int) + parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument("--head_qk", default=0, type=int) # my headQK trick. try 256 if you want to test it - parser.add_argument("--lr_init", default=6e-4, type=float) + parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 parser.add_argument("--lr_final", default=1e-5, type=float) - parser.add_argument("--warmup_steps", default=0, type=int) + parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model parser.add_argument("--beta1", default=0.9, type=float) - parser.add_argument("--beta2", default=0.99, type=float) + parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence parser.add_argument("--adam_eps", default=1e-8, type=float) parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode + parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 500 might be faster (but more VRAM) args = parser.parse_args() args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") @@ -120,7 +131,7 @@ if __name__ == "__main__": # # Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len # -# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps} +# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # # Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer # Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) @@ -134,8 +145,7 @@ if __name__ == "__main__": if not os.path.exists(args.proj_dir): os.makedirs(args.proj_dir) - assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx"] - assert len(args.data_file) > 0 + assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy"] if args.lr_final == 0 or args.lr_init == 0: rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") @@ -196,14 +206,17 @@ if __name__ == "__main__": lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) for param_group in trainer.optimizers[0].param_groups: - if self.args.my_pile_mode == 0: - param_group["lr"] = lr * param_group["my_lr_scale"] - elif self.args.my_pile_mode == 2: - if param_group["my_lr_scale"] > 1: - param_group["lr"] = lr * 5 - else: - param_group["lr"] = lr - # print(param_group["lr"], param_group["my_lr_scale"]) + if args.layerwise_lr > 0: + if self.args.my_pile_mode == 0: + param_group["lr"] = lr * param_group["my_lr_scale"] + elif self.args.my_pile_mode == 2: + if param_group["my_lr_scale"] > 1: + param_group["lr"] = lr * 5 + else: + param_group["lr"] = lr + # print(param_group["lr"], param_group["my_lr_scale"]) + else: + param_group["lr"] = lr trainer.my_lr = lr # rank_zero_info(f"{g_step} {lr}") @@ -311,6 +324,9 @@ if __name__ == "__main__": args, callbacks=[train_callback(args)], ) + if "deepspeed" in args.strategy: + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 # must set shuffle=False, persistent_workers=False (because worker is in another thread) data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)