From ef29f4b9e89ad17079c1c407abf84b234fa5ebcf Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Sun, 15 Aug 2021 21:16:41 +0800
Subject: [PATCH] fixed nan loss

---
 src/model.py   | 19 +++----------------
 src/trainer.py | 29 +++++++++++++----------------
 2 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/src/model.py b/src/model.py
index eb2724a..328a560 100644
--- a/src/model.py
+++ b/src/model.py
@@ -51,6 +51,7 @@ class RWKV_TimeMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
 
+        k = torch.clamp(k, max=30) # clamp crazy values
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
 
@@ -261,20 +262,6 @@ class MHA_pro(nn.Module):
 # The GPT Model with our blocks
 ########################################################################################################
 
-class LabelSmoothingCrossEntropy(nn.Module): # can avoid nan loss
-    def __init__(self, smoothing=0.0):
-        super().__init__()
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-
-    def forward(self, pred, target):
-        pred = pred.log_softmax(dim=-1)
-        with torch.no_grad():
-            true_dist = torch.zeros_like(pred)
-            true_dist.fill_(self.smoothing / (pred.size(-1) - 1))
-            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
-        return torch.mean(torch.sum(-true_dist * pred, dim=-1))
-
 class RMSNorm(nn.Module):
     def __init__(self, d):
         super().__init__()
@@ -379,7 +366,7 @@ class GPT(nn.Module):
                     curve = curve - torch.mean(curve) + 1 # normalize mean to 1
                     mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
                     ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                    # special tweak because of time_shift
+                    # special tweaks because of time_shift
                     ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
                     ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
                     ww[k][h][self.config.ctx_len - 1] = 1
@@ -450,6 +437,6 @@ class GPT(nn.Module):
 
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=5e-5)(x.view(-1, x.size(-1)), targets.view(-1))
+            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
 
         return x, loss
diff --git a/src/trainer.py b/src/trainer.py
index 1fd5123..ead5bb2 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -70,27 +70,29 @@ class Trainer:
                                 batch_size=config.batch_size,
                                 num_workers=config.num_workers)
 
-            losses = []
             pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
 
+            
             for it, (x, y) in pbar:
 
-                # place data on the correct device
-                x = x.to(self.device)
+                x = x.to(self.device) # place data on the correct device
                 y = y.to(self.device)
-
-                # forward the model
+                
                 with torch.set_grad_enabled(is_train):
-                    logits, loss = model(x, y)
-                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
-                    losses.append(loss.item())
-
-                if is_train:
+                    logits, loss = model(x, y) # forward the model
+                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
 
-                    # backprop and update the parameters
+                if is_train: # backprop and update the parameters
                     model.zero_grad()
                     loss.backward()
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                     optimizer.step()
+
+                    # try:
+                    #     torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip, error_if_nonfinite=True)
+                    #     optimizer.step()
+                    # except:
+                    #     pass # ignore nan sample -> sometimes can continue
 
                     # decay the learning rate based on our progress
                     if config.lr_decay:
@@ -124,11 +126,6 @@ class Trainer:
                         self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
                     pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
 
-            if not is_train:
-                test_loss = float(np.mean(losses))
-                logger.info("test loss: %f", test_loss)
-                return test_loss
-
         best_loss = float('inf')
         self.tokens = 0 # counter used for learning rate decay
         for epoch in range(config.max_epochs):
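
Note (not part of the patch): a minimal standalone sketch of why the clamp added in RWKV_TimeMix prevents the nan loss. In float32, exp() overflows to inf once its argument exceeds roughly 88, and the resulting inf turns downstream ratios of the cumulative sums (roughly sum(k*v)/sum(k)) into NaN; capping k at 30 keeps exp(k) finite. The tensor values below are illustrative only.

    import torch

    k = torch.tensor([5.0, 50.0, 100.0])          # stand-in for key pre-activations
    naive = torch.exp(k)                          # last entry overflows to inf in float32
    print(torch.cumsum(naive, dim=0))             # inf propagates through the running sum
    print(naive[-1] / naive[-1])                  # inf / inf -> nan, i.e. the nan loss

    clamped = torch.exp(torch.clamp(k, max=30))   # e**30 ~ 1.07e13, stays finite
    print(torch.cumsum(clamped, dim=0))           # finite running sums, no nan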