multinode

main
BlinkDL 3 years ago
parent 470ac7d1fa
commit 74ffd9bec7

@ -8,12 +8,13 @@ from functools import lru_cache
from itertools import accumulate
def print_rank_0(*message):
"""If distributed is initialized print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(*message, flush=True)
else:
print(*message, flush=True)
pass
# """If distributed is initialized print only on rank 0."""
# if torch.distributed.is_initialized():
# if torch.distributed.get_rank() == 0:
# print(*message, flush=True)
# else:
# print(*message, flush=True)
def _warmup_mmap_file(path):
pass

@ -71,7 +71,7 @@ class train_callback(pl.Callback):
args = self.args
if trainer.is_global_zero: # logging
t_now = time.time_ns()
token_per_step = args.ctx_len * float(args.devices) * args.micro_bsz
token_per_step = args.ctx_len * args.real_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
@ -101,6 +101,7 @@ class train_callback(pl.Callback):
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def on_train_epoch_end(self, trainer, pl_module):
args = self.args

@ -3,7 +3,7 @@
########################################################################################################
if __name__ == "__main__":
print("\n!!! work in progress !!!\n")
print("########## work in progress ##########")
import os, warnings, math, datetime, sys, time
import numpy as np
from argparse import ArgumentParser
@ -108,7 +108,7 @@ if __name__ == "__main__":
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.devices) * args.micro_bsz
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
if not os.path.exists(args.proj_dir):
@ -164,7 +164,7 @@ if __name__ == "__main__":
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.devices}x{args.accelerator.upper()}, bsz {args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#

Loading…
Cancel
Save