From c84e8fd952b09e52837329b4a4f259332477159f Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Wed, 31 Aug 2022 10:29:42 +0800 Subject: [PATCH] bugfix --- RWKV-v4/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py index fb0ff1b..cf1ccc4 100644 --- a/RWKV-v4/train.py +++ b/RWKV-v4/train.py @@ -207,7 +207,7 @@ if __name__ == '__main__': trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16) elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16') - elif os.environ['RWKV_FLOAT_MODE'] == 'fp32': + elif '32' in os.environ['RWKV_FLOAT_MODE']: trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32) else: from pytorch_lightning.strategies import DeepSpeedStrategy @@ -270,7 +270,7 @@ if __name__ == '__main__': } trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16') - elif os.environ['RWKV_FLOAT_MODE'] == 'fp32': + elif '32' in os.environ['RWKV_FLOAT_MODE']: trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32) print(trainer._strategy.config)