From 09c76b185a6660578270c20dd5c3fb820f516b1c Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Sun, 4 Sep 2022 10:27:11 +0800
Subject: [PATCH] no message

---
 RWKV-v4/src/model.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py
index 13dac56..7434ccb 100644
--- a/RWKV-v4/src/model.py
+++ b/RWKV-v4/src/model.py
@@ -8,7 +8,10 @@ import logging
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from deepspeed.ops.adam import FusedAdam
+try:
+    from deepspeed.ops.adam import FusedAdam
+except ImportError:
+    pass  # some Windows users can't install DeepSpeed
 
 logger = logging.getLogger(__name__)
 
@@ -350,7 +353,11 @@ class GPT(nn.Module):
                 for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
 
-        optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        try:
+            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        except NameError:
+            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
+            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
 
         return optimizer
 
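
Note: the code below is a minimal, self-contained sketch of the fallback pattern this patch introduces, not the project's actual code. It assumes optim_groups (a list of parameter-group dicts) and a train_config object carrying learning_rate, betas, and eps, both taken from the surrounding model.py; the build_optimizer helper name is hypothetical. Probing the import once and recording the result in a flag is a slightly tightened form of the patch's try/except, so a genuine error raised inside the FusedAdam constructor is not silently swallowed.

# Minimal sketch (not the project's code) of the optimizer fallback above.
# Assumes optim_groups and train_config exist as in RWKV-v4/src/model.py;
# build_optimizer is a hypothetical helper name used only for illustration.
import torch

try:
    from deepspeed.ops.adam import FusedAdam
    HAS_FUSED_ADAM = True
except ImportError:
    HAS_FUSED_ADAM = False  # DeepSpeed is optional; often hard to install on Windows

def build_optimizer(optim_groups, train_config):
    if HAS_FUSED_ADAM:
        # adam_w_mode=False makes FusedAdam behave like plain Adam (no decoupled
        # weight decay), which matches the torch.optim.Adam fallback below.
        return FusedAdam(optim_groups, lr=train_config.learning_rate,
                         betas=train_config.betas, eps=train_config.eps,
                         bias_correction=True, adam_w_mode=False,
                         weight_decay=0, amsgrad=False)
    print('DeepSpeed not found. Using torch optimizer instead (probably slower)')
    return torch.optim.Adam(optim_groups, lr=train_config.learning_rate,
                            betas=train_config.betas, eps=train_config.eps)

With this shape, optimizer = build_optimizer(optim_groups, train_config) behaves the same on machines with and without DeepSpeed, apart from speed.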