|
|
|
|
@ -3,7 +3,7 @@
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
|
|
|
|
|
print("\n!!! NOTE: THIS IS STILL WIP (and a bit slower than RWKV-4) !!!\n")
|
|
|
|
|
import os, warnings, math, datetime
|
|
|
|
|
import numpy as np
|
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
@ -16,19 +16,6 @@ if __name__ == "__main__":
|
|
|
|
|
from pytorch_lightning.callbacks import TQDMProgressBar
|
|
|
|
|
from pytorch_lightning import Callback
|
|
|
|
|
|
|
|
|
|
rank_zero_info(
|
|
|
|
|
f"""
|
|
|
|
|
############################################################################
|
|
|
|
|
#
|
|
|
|
|
# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
|
|
|
|
|
#
|
|
|
|
|
# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
|
|
|
|
|
#
|
|
|
|
|
# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
|
|
|
|
|
#
|
|
|
|
|
############################################################################
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
seed_everything(42)
|
|
|
|
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
|
|
|
|
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
|
|
|
|
|
@ -87,7 +74,7 @@ if __name__ == "__main__":
|
|
|
|
|
f"""
|
|
|
|
|
############################################################################
|
|
|
|
|
#
|
|
|
|
|
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strateg.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
|
|
|
|
|
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
|
|
|
|
|
#
|
|
|
|
|
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
|
|
|
|
|
#
|
|
|
|
|
@ -99,6 +86,10 @@ if __name__ == "__main__":
|
|
|
|
|
#
|
|
|
|
|
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
|
|
|
|
|
#
|
|
|
|
|
# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
|
|
|
|
|
# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
|
|
|
|
|
# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
|
|
|
|
|
#
|
|
|
|
|
############################################################################
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
|