no message

main
BlinkDL 3 years ago
parent c49fd38ba1
commit 09c76b185a

@ -8,7 +8,10 @@ import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
    from deepspeed.ops.adam import FusedAdam
except:
    pass # some poor windows users cant install deepspeed
logger = logging.getLogger(__name__)
@ -350,7 +353,11 @@ class GPT(nn.Module):
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        try:
            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        except:
            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
        return optimizer

Loading…
Cancel
Save