main
BlinkDL 3 years ago
parent 73b96705d7
commit c84e8fd952

@ -207,7 +207,7 @@ if __name__ == '__main__':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
else:
from pytorch_lightning.strategies import DeepSpeedStrategy
@ -270,7 +270,7 @@ if __name__ == '__main__':
}
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
print(trainer._strategy.config)

Loading…
Cancel
Save