main
BlinkDL 3 years ago
parent 58e9d8d972
commit 9f557219c4

@ -27,7 +27,7 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
6: float,
7: np.double,
8: np.uint16,
}

@ -28,7 +28,7 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
6: float,
7: np.double,
8: np.uint16,
}

@ -239,6 +239,7 @@ if __name__ == "__main__":
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

Loading…
Cancel
Save