diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py
index d8c6eba..24b5db7 100644
--- a/RWKV-v4/train.py
+++ b/RWKV-v4/train.py
@@ -65,7 +65,7 @@ os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflo
 
 os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
 
-if int(os.environ['RWKV_NUM_GPUS']) == 1 and os.environ['RWKV_FLOAT_MODE'] == 'fp32': # the only case where DeepSpeed is worse
+if int(os.environ['RWKV_NUM_GPUS']) == 1: # turn off DeepSpeed for 1 GPU training
     os.environ['RWKV_DEEPSPEED'] = '0'
 
 os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True