main
BlinkDL 3 years ago
parent 73b96705d7
commit c84e8fd952

@ -207,7 +207,7 @@ if __name__ == '__main__':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16) trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16') trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif os.environ['RWKV_FLOAT_MODE'] == 'fp32': elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32) trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
else: else:
from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.strategies import DeepSpeedStrategy
@ -270,7 +270,7 @@ if __name__ == '__main__':
} }
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16') trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
elif os.environ['RWKV_FLOAT_MODE'] == 'fp32': elif '32' in os.environ['RWKV_FLOAT_MODE']:
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32) trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
print(trainer._strategy.config) print(trainer._strategy.config)

Loading…
Cancel
Save