BlinkDL 3 years ago
parent cdb098c0e0
commit 23b0c74950

@ -3,7 +3,7 @@
######################################################################################################## ########################################################################################################
if __name__ == "__main__": if __name__ == "__main__":
print("\n!!! NOTE: THIS IS STILL WIP !!!\n") print("\n!!! NOTE: THIS IS STILL WIP (and a bit slower than RWKV-4) !!!\n")
import os, warnings, math, datetime import os, warnings, math, datetime
import numpy as np import numpy as np
from argparse import ArgumentParser from argparse import ArgumentParser
@ -16,19 +16,6 @@ if __name__ == "__main__":
from pytorch_lightning.callbacks import TQDMProgressBar from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning import Callback from pytorch_lightning import Callback
rank_zero_info(
f"""
############################################################################
#
# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
#
# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
#
# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
)
seed_everything(42) seed_everything(42)
np.set_printoptions(precision=4, suppress=True, linewidth=200) np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
@ -87,7 +74,7 @@ if __name__ == "__main__":
f""" f"""
############################################################################ ############################################################################
# #
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strateg.upper()} {'with grad_cp' if args.grad_cp > 0 else ''} # RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
# #
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
# #
@ -99,6 +86,10 @@ if __name__ == "__main__":
# #
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps} # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
# #
# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################ ############################################################################
""" """
) )

Loading…
Cancel
Save