main
BlinkDL 3 years ago
parent 58e9d8d972
commit 9f557219c4

@ -27,7 +27,7 @@ dtypes = {
3: np.int16, 3: np.int16,
4: np.int32, 4: np.int32,
5: np.int64, 5: np.int64,
6: np.float, 6: float,
7: np.double, 7: np.double,
8: np.uint16, 8: np.uint16,
} }

@ -28,7 +28,7 @@ dtypes = {
3: np.int16, 3: np.int16,
4: np.int32, 4: np.int32,
5: np.int64, 5: np.int64,
6: np.float, 6: float,
7: np.double, 7: np.double,
8: np.uint16, 8: np.uint16,
} }

@ -239,7 +239,8 @@ if __name__ == "__main__":
assert args.precision in ["fp32", "tf32", "fp16", "bf16"] assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32": if args.precision == "fp32":
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16": if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

Loading…
Cancel
Save