fixed nan loss

main
BlinkDL 4 years ago
parent 4fd8716976
commit ef29f4b9e8

@@ -51,6 +51,7 @@ class RWKV_TimeMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
+        k = torch.clamp(k, max=30) # clamp crazy values so exp(k) cannot overflow
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
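The added clamp is the heart of this "fixed nan loss" commit: a large k makes torch.exp(k) overflow to inf in float32, and the later normalization by sum_k then produces NaN, which poisons the loss. A minimal standalone sketch of the failure and the fix (toy values, not repo code; max=30 is the threshold chosen above):

    import torch

    k = torch.tensor([5.0, 50.0, 100.0])      # exp(100) overflows float32 (limit is roughly exp(88))
    w = torch.exp(k)
    print(w / torch.cumsum(w, dim=0))          # inf / inf -> nan appears

    k = torch.clamp(k, max=30)                 # same guard as the patch
    w = torch.exp(k)
    print(w / torch.cumsum(w, dim=0))          # everything stays finite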
@@ -261,20 +262,6 @@ class MHA_pro(nn.Module):
 # The GPT Model with our blocks
 ########################################################################################################
-class LabelSmoothingCrossEntropy(nn.Module): # can avoid nan loss
-    def __init__(self, smoothing=0.0):
-        super().__init__()
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-    def forward(self, pred, target):
-        pred = pred.log_softmax(dim=-1)
-        with torch.no_grad():
-            true_dist = torch.zeros_like(pred)
-            true_dist.fill_(self.smoothing / (pred.size(-1) - 1))
-            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
-        return torch.mean(torch.sum(-true_dist * pred, dim=-1))
 class RMSNorm(nn.Module):
     def __init__(self, d):
         super().__init__()
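For reference, the removed LabelSmoothingCrossEntropy builds a smoothed one-hot target and takes the expected negative log-likelihood; with smoothing=0 it collapses to plain cross-entropy, which is what the model now uses directly. A quick equivalence check under that assumption (toy tensors, and it assumes the removed class definition is still in scope):

    import torch
    import torch.nn.functional as F

    pred = torch.randn(4, 10)                  # (batch, vocab) logits
    target = torch.randint(0, 10, (4,))

    smoothed = LabelSmoothingCrossEntropy(smoothing=0.0)(pred, target)   # class from the removed hunk
    plain = F.cross_entropy(pred, target)
    assert torch.allclose(smoothed, plain, atol=1e-6)

On PyTorch 1.10 and newer, F.cross_entropy(..., label_smoothing=5e-5) provides a very similar smoothing effect without the extra class.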
@@ -379,7 +366,7 @@ class GPT(nn.Module):
                 curve = curve - torch.mean(curve) + 1 # normalize mean to 1
                 mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
                 ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                # special tweak because of time_shift
+                # special tweaks because of time_shift
                 ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
                 ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
                 ww[k][h][self.config.ctx_len - 1] = 1
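The surrounding loop blends a positional curve into the per-head time weights: ww[k][h] = (1 - mix_strength) + curve * mix_strength, so mix_strength = 1 keeps the curve as-is, 0 flattens it to all ones, and the negative end mildly inverts it. A tiny sketch of the linear ramp across heads (hypothetical head count; the real n_head comes from the config):

    n_head = 12                                    # hypothetical
    for h in range(n_head):
        mix_strength = 1 - 1.2 * h / (n_head - 1)  # 1.0 for the first head, -0.2 for the last
        print(h, round(mix_strength, 3))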
@@ -450,6 +437,6 @@ class GPT(nn.Module):
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=5e-5)(x.view(-1, x.size(-1)), targets.view(-1))
+            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
         return x, loss
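The replacement loss flattens the (batch, ctx_len, vocab) logits to 2-D and the targets to 1-D, which is the layout F.cross_entropy expects; the trainer's (y >= 0).sum() bookkeeping further down suggests -100 is used as an ignore label, and F.cross_entropy skips those positions by default (ignore_index=-100). A shape-only sketch with made-up sizes:

    import torch
    import torch.nn.functional as F

    B, T, V = 2, 8, 100                        # hypothetical batch, context length, vocab size
    x = torch.randn(B, T, V)                   # logits as returned by the model
    targets = torch.randint(0, V, (B, T))
    targets[:, -1] = -100                      # e.g. a padded position, ignored by default

    loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
    print(loss.item())                         # scalar mean over the non-ignored positions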

@ -70,28 +70,30 @@ class Trainer:
batch_size=config.batch_size,
num_workers=config.num_workers)
losses = []
pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
# place data on the correct device
x = x.to(self.device)
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
# forward the model
with torch.set_grad_enabled(is_train):
logits, loss = model(x, y)
logits, loss = model(x, y) # forward the model
loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
losses.append(loss.item())
if is_train:
# backprop and update the parameters
if is_train: # backprop and update the parameters
model.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
optimizer.step()
# try:
# torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip, error_if_nonfinite=True)
# optimizer.step()
# except:
# pass # ignore nan sample -> sometimes can continue
# decay the learning rate based on our progress
if config.lr_decay:
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
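The commented-out block sketches a stricter guard: with error_if_nonfinite=True (PyTorch 1.9+), clip_grad_norm_ raises a RuntimeError when the gradient norm is NaN or inf, so the optimizer step for that batch can simply be skipped. A standalone version of that pattern, assuming model, optimizer, loss and a grad_norm_clip value are already defined:

    import torch

    def guarded_step(model, optimizer, loss, grad_norm_clip):
        model.zero_grad()
        loss.backward()
        try:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip,
                                           error_if_nonfinite=True)
            optimizer.step()
        except RuntimeError:
            pass  # gradients were NaN/inf: drop this batch; the next zero_grad() clears them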
@@ -124,11 +126,6 @@ class Trainer:
                     self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
                     pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
-            if not is_train:
-                test_loss = float(np.mean(losses))
-                logger.info("test loss: %f", test_loss)
-                return test_loss
         best_loss = float('inf')
         self.tokens = 0 # counter used for learning rate decay
         for epoch in range(config.max_epochs):
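The progress line above reports an exponentially smoothed loss and its perplexity, ppl = exp(avg_loss). A tiny standalone sketch of that smoothing (the factor is defined outside this hunk, so the value here is made up):

    import math

    factor = 1 / 300                           # hypothetical smoothing factor
    avg_loss = 5.2                             # pretend first-iteration loss
    for now_loss in [4.9, 4.7, 4.8]:           # fake per-iteration losses
        avg_loss = avg_loss * (1.0 - factor) + now_loss * factor
        print(f"ppl {math.exp(avg_loss):.2f} loss {avg_loss:.4f}")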
