diff --git a/RWKV-v4/src/binidx.py b/RWKV-v4/src/binidx.py index 43fefaa..ce6cfe2 100644 --- a/RWKV-v4/src/binidx.py +++ b/RWKV-v4/src/binidx.py @@ -27,7 +27,7 @@ dtypes = { 3: np.int16, 4: np.int32, 5: np.int64, - 6: np.float, + 6: float, 7: np.double, 8: np.uint16, } diff --git a/RWKV-v4neo/src/binidx.py b/RWKV-v4neo/src/binidx.py index f8365f3..369081a 100644 --- a/RWKV-v4neo/src/binidx.py +++ b/RWKV-v4neo/src/binidx.py @@ -28,7 +28,7 @@ dtypes = { 3: np.int16, 4: np.int32, 5: np.int64, - 6: np.float, + 6: float, 7: np.double, 8: np.uint16, } diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index d148d90..0460879 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -239,7 +239,8 @@ if __name__ == "__main__": assert args.precision in ["fp32", "tf32", "fp16", "bf16"] os.environ["RWKV_FLOAT_MODE"] = args.precision if args.precision == "fp32": - rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") + for i in range(10): + rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") if args.precision == "fp16": rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")