|
|
|
|
@ -207,7 +207,7 @@ if __name__ == '__main__':
|
|
|
|
|
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
|
|
|
|
|
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
|
|
|
|
|
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
|
|
|
|
|
elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
|
|
|
|
|
elif '32' in os.environ['RWKV_FLOAT_MODE']:
|
|
|
|
|
trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
|
|
|
|
|
else:
|
|
|
|
|
from pytorch_lightning.strategies import DeepSpeedStrategy
|
|
|
|
|
@ -270,7 +270,7 @@ if __name__ == '__main__':
|
|
|
|
|
}
|
|
|
|
|
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
|
|
|
|
|
|
|
|
|
|
elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
|
|
|
|
|
elif '32' in os.environ['RWKV_FLOAT_MODE']:
|
|
|
|
|
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
|
|
|
|
|
|
|
|
|
|
print(trainer._strategy.config)
|
|
|
|
|
|