multinode

commit 74ffd9bec7 (main)
parent 470ac7d1fa
BlinkDL · 3 years ago

@@ -8,12 +8,13 @@ from functools import lru_cache
 from itertools import accumulate
 
 def print_rank_0(*message):
-    """If distributed is initialized print only on rank 0."""
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
-            print(*message, flush=True)
-    else:
-        print(*message, flush=True)
+    pass
+    # """If distributed is initialized print only on rank 0."""
+    # if torch.distributed.is_initialized():
+    #     if torch.distributed.get_rank() == 0:
+    #         print(*message, flush=True)
+    # else:
+    #     print(*message, flush=True)
 
 def _warmup_mmap_file(path):
     pass
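For reference, the rank-0 printing pattern this hunk comments out, restored as a standalone sketch (it assumes torch is installed; the behavior matches the removed lines):

import torch.distributed as dist

def print_rank_0(*message):
    # In a distributed run, only rank 0 prints; in a plain
    # single-process run, print unconditionally.
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(*message, flush=True)
    else:
        print(*message, flush=True)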

@ -71,7 +71,7 @@ class train_callback(pl.Callback):
args = self.args args = self.args
if trainer.is_global_zero: # logging if trainer.is_global_zero: # logging
t_now = time.time_ns() t_now = time.time_ns()
token_per_step = args.ctx_len * float(args.devices) * args.micro_bsz token_per_step = args.ctx_len * args.real_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
try: try:
t_cost = (t_now - trainer.my_time_ns) / 1e9 t_cost = (t_now - trainer.my_time_ns) / 1e9
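To see what the changed line feeds into, here is a minimal standalone sketch of the throughput logging around it (variable names follow the diff; the concrete numbers and the kt/s conversion are assumptions for illustration):

import time

ctx_len, real_bsz = 1024, 64                 # hypothetical values
token_per_step = ctx_len * real_bsz          # now counts GPUs on every node

t_prev_ns = time.time_ns()
# ... one optimizer step would run here ...
t_cost = (time.time_ns() - t_prev_ns) / 1e9  # seconds per step
print(f"throughput: {token_per_step / t_cost / 1000:.1f} kt/s")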
@ -101,6 +101,7 @@ class train_callback(pl.Callback):
dataset.global_rank = trainer.global_rank dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size dataset.world_size = trainer.world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
args = self.args args = self.args
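The three attributes injected above are one common way to shard a dataset by hand instead of using a distributed sampler. A hypothetical minimal dataset consuming them (not the repo's actual sampling logic) could look like:

import torch

class ShardedDataset(torch.utils.data.Dataset):
    """Hypothetical sketch: global_rank / world_size / real_epoch
    are assigned externally by the callback, as in the hunk above."""

    def __init__(self, data, epoch_steps, micro_bsz):
        self.data = data
        self.epoch_steps = epoch_steps
        self.micro_bsz = micro_bsz
        self.global_rank = 0   # overwritten by the callback
        self.world_size = 1    # overwritten by the callback
        self.real_epoch = 0    # overwritten by the callback

    def __len__(self):
        return self.epoch_steps * self.micro_bsz

    def __getitem__(self, idx):
        # Stride by world_size so each rank reads a disjoint slice.
        real_idx = idx * self.world_size + self.global_rank
        return self.data[real_idx % len(self.data)]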

@@ -3,7 +3,7 @@
 ########################################################################################################
 if __name__ == "__main__":
-    print("\n!!! work in progress !!!\n")
+    print("########## work in progress ##########")
     import os, warnings, math, datetime, sys, time
     import numpy as np
     from argparse import ArgumentParser
@@ -108,7 +108,7 @@ if __name__ == "__main__":
     args.log_every_n_steps = int(1e20)
     args.max_epochs = -1  # continue forever
     args.betas = (args.beta1, args.beta2)
-    args.real_bsz = int(args.devices) * args.micro_bsz
+    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
     os.environ["RWKV_T_MAX"] = str(args.ctx_len)
 
     if not os.path.exists(args.proj_dir):
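A quick worked check of the corrected formula, with made-up cluster numbers (not from the source):

num_nodes, devices, micro_bsz = 2, 8, 4  # hypothetical: 2 nodes x 8 GPUs each
real_bsz = int(num_nodes) * int(devices) * micro_bsz
assert real_bsz == 64
# The old expression, int(devices) * micro_bsz, would give 32 here,
# silently undercounting the global batch on any multi-node run.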
@@ -164,7 +164,7 @@ if __name__ == "__main__":
         f"""
 ############################################################################
 #
-# RWKV-4 {args.precision.upper()} on {args.devices}x{args.accelerator.upper()}, bsz {args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
+# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
 #
 # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
 #
