From 2815260d8319a0948b3f36d3f3849941a32a7f8c Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Fri, 2 Sep 2022 17:53:39 +0800
Subject: [PATCH] better

---
 RWKV-v4/src/model.py   | 17 ++++++++++++++++-
 RWKV-v4/src/trainer.py | 29 ++++++-----------------------
 2 files changed, 22 insertions(+), 24 deletions(-)

diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py
index 47c201f..13dac56 100644
--- a/RWKV-v4/src/model.py
+++ b/RWKV-v4/src/model.py
@@ -15,6 +15,21 @@ logger = logging.getLogger(__name__)
 RWKV_HEAD_QK_DIM = 0
 print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
 
+class L2Wrap(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, loss, y):
+        ctx.save_for_backward(y)
+        return loss
+    @staticmethod
+    def backward(ctx, grad_output):
+        y = ctx.saved_tensors[0]
+        # to encourage the logits to be close to 0
+        factor = 1e-4 / (y.shape[0] * y.shape[1])
+        maxx, ids = torch.max(y, -1, keepdim=True)
+        gy = torch.zeros_like(y)
+        gy.scatter_(-1, ids, maxx * factor)
+        return (grad_output, gy)
+
 ########################################################################################################
 # CUDA Kernel
 ########################################################################################################
@@ -371,4 +386,4 @@ class GPT(nn.Module):
 
         if targets is not None:
             loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
-        return x, loss
+        return L2Wrap.apply(loss, x)
diff --git a/RWKV-v4/src/trainer.py b/RWKV-v4/src/trainer.py
index 645cc6c..8025cd5 100644
--- a/RWKV-v4/src/trainer.py
+++ b/RWKV-v4/src/trainer.py
@@ -25,21 +25,6 @@ else:
     torch.backends.cudnn.allow_tf32 = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
-class L2Wrap(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, loss, y):
-        ctx.save_for_backward(y)
-        return loss
-    @staticmethod
-    def backward(ctx, grad_output):
-        y = ctx.saved_tensors[0]
-        # to encourage the logits to be close to 0
-        factor = 1e-4 / (y.shape[0] * y.shape[1])
-        maxx, ids = torch.max(y, -1, keepdim=True)
-        gy = torch.zeros_like(y)
-        gy.scatter_(-1, ids, maxx * factor)
-        return (grad_output, gy)
-
 class TrainerConfig:
     batch_size = 64
     learning_rate = 4e-4
@@ -74,14 +59,13 @@ class Trainer(LightningLite):
         model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type,
                     n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd))
         print('[1]')
-        model.to(self.device)
-        print('[2]')
         with torch.no_grad():
             if m_cfg.LOAD_MODEL:
                 print('loading', m_cfg.MODEL_NAME)
-                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location=torch.device(self.device))
+                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location='cpu')
                 model.load_state_dict(m2)
                 del m2
+        model.to(self.device)
 
         self.model = model
         self.train_dataset = train_dataset
@@ -106,8 +90,6 @@ class Trainer(LightningLite):
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)
         model, optimizer = self.setup(model, optimizer)
-        gc.collect()
-        torch.cuda.empty_cache()
         print('[3]')
 
         def run_epoch(split):
@@ -129,11 +111,12 @@ class Trainer(LightningLite):
             pbar = tqdm(enumerate(loader), total=len(
                 loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
             loader = self.setup_dataloaders(loader)
+            gc.collect()
+            torch.cuda.empty_cache()
 
             for it, (x, y) in pbar:
                 with torch.set_grad_enabled(is_train):
-                    yyy, loss = model(x, y) # forward the model
-                    lossL2 = L2Wrap.apply(loss, yyy)
+                    loss = model(x, y) # forward the model
 
                     if os.environ['RWKV_DEEPSPEED'] == '0':
                         all_loss = [loss.clone()]
@@ -143,7 +126,7 @@ class Trainer(LightningLite):
 
                 if is_train: # backprop and update the parameters
                     model.zero_grad()
-                    self.backward(lossL2)
+                    self.backward(loss)
                     # deepspeed will handle gradient_clipping
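
After this patch, GPT.forward() returns the L2Wrap-wrapped loss instead of an (x, loss) tuple, which is why trainer.py no longer keeps its own copy of L2Wrap and can call self.backward(loss) directly. Below is a minimal standalone sketch of what the wrapper's backward pass does; the toy shapes (B=2, T=3, V=5) and variable names are illustrative assumptions, not taken from the repo.

import torch
import torch.nn.functional as F

# L2Wrap as added to RWKV-v4/src/model.py by this patch: forward is an identity
# on the loss; backward additionally nudges each position's largest logit toward 0.
class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)

# Toy example with assumed shapes (not the real model): B=2 sequences, T=3 tokens, V=5 vocab.
B, T, V = 2, 3, 5
logits = torch.randn(B, T, V, requires_grad=True)
targets = torch.randint(0, V, (B, T))

loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))
loss = L2Wrap.apply(loss, logits)   # what GPT.forward() now returns
loss.backward()

# logits.grad is the usual cross-entropy gradient plus an extra term of size
# max_logit * 1e-4 / (B * T) at each position's argmax, pulling large logits toward 0.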