|
|
|
|
@ -239,7 +239,8 @@ if __name__ == "__main__":
|
|
|
|
|
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
|
|
|
|
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
|
|
|
|
if args.precision == "fp32":
|
|
|
|
|
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
|
|
|
|
|
for i in range(10):
|
|
|
|
|
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
|
|
|
|
|
if args.precision == "fp16":
|
|
|
|
|
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
|
|
|
|
|
|
|
|
|
|
|