diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py
index 13dac56..7434ccb 100644
--- a/RWKV-v4/src/model.py
+++ b/RWKV-v4/src/model.py
@@ -8,7 +8,10 @@ import logging
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from deepspeed.ops.adam import FusedAdam
+try:
+    from deepspeed.ops.adam import FusedAdam
+except ImportError:
+    pass  # some poor Windows users can't install DeepSpeed
 
 logger = logging.getLogger(__name__)
 
@@ -350,7 +353,11 @@ class GPT(nn.Module):
                         for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
 
-        optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        try:
+            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        except NameError:
+            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
+            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
         return optimizer
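The pattern in this patch, a guarded import plus a construction-time fallback, is easy to lift out of the model class. Below is a minimal self-contained sketch of the same idea; the Config dataclass, its hyperparameter values, and the make_optimizer helper are illustrative stand-ins for train_config and GPT.configure_optimizers, not part of the patch.

    from dataclasses import dataclass

    import torch

    try:
        from deepspeed.ops.adam import FusedAdam  # CUDA-fused Adam kernel
    except ImportError:
        FusedAdam = None  # DeepSpeed unavailable; record that fact explicitly

    @dataclass
    class Config:  # illustrative stand-in for train_config
        learning_rate: float = 6e-4
        betas: tuple = (0.9, 0.99)
        eps: float = 1e-8

    def make_optimizer(params, cfg: Config):
        # Prefer the fused kernel when DeepSpeed imported cleanly; otherwise
        # fall back to torch.optim.Adam, functionally equivalent but slower.
        if FusedAdam is not None:
            return FusedAdam(params, lr=cfg.learning_rate, betas=cfg.betas,
                             eps=cfg.eps, bias_correction=True,
                             adam_w_mode=False, weight_decay=0, amsgrad=False)
        print('DeepSpeed not found. Using torch optimizer instead (probably slower)')
        return torch.optim.Adam(params, lr=cfg.learning_rate,
                                betas=cfg.betas, eps=cfg.eps)

    # usage: works with or without deepspeed installed
    model = torch.nn.Linear(4, 4)
    opt = make_optimizer(model.parameters(), Config())
    print(type(opt).__name__)  # FusedAdam or Adam

Binding FusedAdam = None in the except branch turns the fallback into an explicit branch, instead of relying on the NameError the patch catches at optimizer-construction time; either works, but the explicit flag makes the decision point visible in one place.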