From 74ffd9bec713a1a26c861d0c790e7573ddcf23c9 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Fri, 9 Sep 2022 18:50:04 +0000 Subject: [PATCH] multinode --- RWKV-v4neo/src/binidx.py | 13 +++++++------ RWKV-v4neo/src/trainer.py | 3 ++- RWKV-v4neo/train.py | 6 +++--- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/RWKV-v4neo/src/binidx.py b/RWKV-v4neo/src/binidx.py index 43fefaa..404a581 100644 --- a/RWKV-v4neo/src/binidx.py +++ b/RWKV-v4neo/src/binidx.py @@ -8,12 +8,13 @@ from functools import lru_cache from itertools import accumulate def print_rank_0(*message): - """If distributed is initialized print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(*message, flush=True) - else: - print(*message, flush=True) + pass + # """If distributed is initialized print only on rank 0.""" + # if torch.distributed.is_initialized(): + # if torch.distributed.get_rank() == 0: + # print(*message, flush=True) + # else: + # print(*message, flush=True) def _warmup_mmap_file(path): pass diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 6337030..d42856c 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -71,7 +71,7 @@ class train_callback(pl.Callback): args = self.args if trainer.is_global_zero: # logging t_now = time.time_ns() - token_per_step = args.ctx_len * float(args.devices) * args.micro_bsz + token_per_step = args.ctx_len * args.real_bsz real_step = trainer.global_step + args.epoch_begin * args.epoch_steps try: t_cost = (t_now - trainer.my_time_ns) / 1e9 @@ -101,6 +101,7 @@ class train_callback(pl.Callback): dataset.global_rank = trainer.global_rank dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) dataset.world_size = trainer.world_size + # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########') def on_train_epoch_end(self, trainer, pl_module): args = self.args diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index bd4847c..1246990 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -3,7 +3,7 @@ ######################################################################################################## if __name__ == "__main__": - print("\n!!! work in progress !!!\n") + print("########## work in progress ##########") import os, warnings, math, datetime, sys, time import numpy as np from argparse import ArgumentParser @@ -108,7 +108,7 @@ if __name__ == "__main__": args.log_every_n_steps = int(1e20) args.max_epochs = -1 # continue forever args.betas = (args.beta1, args.beta2) - args.real_bsz = int(args.devices) * args.micro_bsz + args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz os.environ["RWKV_T_MAX"] = str(args.ctx_len) if not os.path.exists(args.proj_dir): @@ -164,7 +164,7 @@ if __name__ == "__main__": f""" ############################################################################ # -# RWKV-4 {args.precision.upper()} on {args.devices}x{args.accelerator.upper()}, bsz {args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} +# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} # # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} #