@@ -8,7 +8,10 @@ import logging
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from deepspeed.ops.adam import FusedAdam
+try:
+    from deepspeed.ops.adam import FusedAdam
+except:
+    pass # some poor windows users cant install deepspeed
 
 logger = logging.getLogger(__name__)
 
@@ -350,7 +353,11 @@ class GPT(nn.Module):
                         for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
 
-        optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        try:
+            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        except:
+            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
+            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
 
         return optimizer
 
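For reference, the sketch below isolates the optional-dependency pattern this patch introduces: prefer DeepSpeed's FusedAdam when it can be imported, otherwise fall back to torch.optim.Adam. It is a minimal sketch, not the patched file itself: the helper name `build_optimizer` and the `HAS_DEEPSPEED` flag are illustrative assumptions, and it catches `ImportError` explicitly where the patch uses a bare `except` at both the import and call sites. The optimizer arguments are copied from the diff.

```python
import torch

# Assumption: gate the optional DeepSpeed dependency behind an explicit flag
# rather than relying on a NameError at the call site, as the patch does.
try:
    from deepspeed.ops.adam import FusedAdam  # may be unavailable, e.g. on Windows
    HAS_DEEPSPEED = True
except ImportError:
    HAS_DEEPSPEED = False


def build_optimizer(optim_groups, train_config):
    """Return FusedAdam when DeepSpeed is installed, else torch.optim.Adam."""
    if HAS_DEEPSPEED:
        # Same arguments as the patch: adam_w_mode=False keeps classic Adam
        # (L2-style) decay, and weight_decay=0 is only the default -- the
        # per-group "weight_decay" values in optim_groups still apply.
        return FusedAdam(optim_groups,
                         lr=train_config.learning_rate,
                         betas=train_config.betas,
                         eps=train_config.eps,
                         bias_correction=True,
                         adam_w_mode=False,
                         weight_decay=0,
                         amsgrad=False)
    print('DeepSpeed not found. Using torch optimizer instead (probably slower)')
    return torch.optim.Adam(optim_groups,
                            lr=train_config.learning_rate,
                            betas=train_config.betas,
                            eps=train_config.eps)
```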