main
BlinkDL 3 years ago
parent dc7e0802d0
commit 2815260d83

@@ -15,6 +15,21 @@ logger = logging.getLogger(__name__)
 RWKV_HEAD_QK_DIM = 0
 print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
 
+class L2Wrap(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, loss, y):
+        ctx.save_for_backward(y)
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        y = ctx.saved_tensors[0]
+        # to encourage the logits to be close to 0
+        factor = 1e-4 / (y.shape[0] * y.shape[1])
+        maxx, ids = torch.max(y, -1, keepdim=True)
+        gy = torch.zeros_like(y)
+        gy.scatter_(-1, ids, maxx * factor)
+        return (grad_output, gy)
 ########################################################################################################
 # CUDA Kernel
 ########################################################################################################
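A quick way to see what the extra gradient above does: with a dummy loss whose own gradient w.r.t. the logits is zero, the only gradient that reaches y is the L2Wrap term, i.e. 1e-4 / (B*T) times each row's maximum logit, placed at the argmax position, so gradient descent nudges the largest logit toward 0. The sanity check below is an illustrative sketch, not code from this commit; the shapes are made up and it assumes the L2Wrap class above is in scope.

import torch

B, T, V = 2, 3, 5                              # illustrative batch, ctx_len, vocab sizes
y = torch.randn(B, T, V, requires_grad=True)   # stand-in for the model logits
dummy_loss = (y * 0.0).sum()                   # contributes zero gradient to y

L2Wrap.apply(dummy_loss, y).backward()         # y.grad now holds only the L2Wrap term

with torch.no_grad():
    factor = 1e-4 / (B * T)
    maxx, ids = torch.max(y, -1, keepdim=True)
    expected = torch.zeros_like(y).scatter_(-1, ids, maxx * factor)
assert torch.allclose(y.grad, expected)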
@@ -371,4 +386,4 @@ class GPT(nn.Module):
         if targets is not None:
             loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
 
-        return x, loss
+        return L2Wrap.apply(loss, x)

@@ -25,21 +25,6 @@ else:
     torch.backends.cudnn.allow_tf32 = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
-class L2Wrap(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, loss, y):
-        ctx.save_for_backward(y)
-        return loss
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        y = ctx.saved_tensors[0]
-        # to encourage the logits to be close to 0
-        factor = 1e-4 / (y.shape[0] * y.shape[1])
-        maxx, ids = torch.max(y, -1, keepdim=True)
-        gy = torch.zeros_like(y)
-        gy.scatter_(-1, ids, maxx * factor)
-        return (grad_output, gy)
 class TrainerConfig:
     batch_size = 64
     learning_rate = 4e-4
@@ -74,14 +59,13 @@ class Trainer(LightningLite):
         model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type,
                               n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd))
         print('[1]')
-        model.to(self.device)
-        print('[2]')
         with torch.no_grad():
             if m_cfg.LOAD_MODEL:
                 print('loading', m_cfg.MODEL_NAME)
-                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location=torch.device(self.device))
+                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location='cpu')
                 model.load_state_dict(m2)
                 del m2
+        model.to(self.device)
 
         self.model = model
         self.train_dataset = train_dataset
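The load path above now deserializes the checkpoint onto the CPU and only calls model.to(self.device) after load_state_dict, which keeps the checkpoint tensors off the GPU while loading. A minimal sketch of that pattern; the path and device below are placeholders, not values from the repo:

state = torch.load('checkpoint.pth', map_location='cpu')  # deserialize into CPU memory
model.load_state_dict(state)                              # copy weights into the model
del state                                                 # drop the CPU copy
model.to('cuda')                                          # move to the target device once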
@@ -106,8 +90,6 @@ class Trainer(LightningLite):
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)
         model, optimizer = self.setup(model, optimizer)
-        gc.collect()
-        torch.cuda.empty_cache()
         print('[3]')
 
         def run_epoch(split):
@@ -129,11 +111,12 @@ class Trainer(LightningLite):
             pbar = tqdm(enumerate(loader), total=len(
                 loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
             loader = self.setup_dataloaders(loader)
+            gc.collect()
+            torch.cuda.empty_cache()
 
             for it, (x, y) in pbar:
                 with torch.set_grad_enabled(is_train):
-                    yyy, loss = model(x, y) # forward the model
-                    lossL2 = L2Wrap.apply(loss, yyy)
+                    loss = model(x, y) # forward the model
 
                 if os.environ['RWKV_DEEPSPEED'] == '0':
                     all_loss = [loss.clone()]
@@ -143,7 +126,7 @@ class Trainer(LightningLite):
 
                 if is_train: # backprop and update the parameters
                     model.zero_grad()
-                    self.backward(lossL2)
+                    self.backward(loss)
 
                     # deepspeed will handle gradient_clipping
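With forward() returning L2Wrap.apply(loss, x), the training step only ever sees one scalar, and a single backward pass picks up both the cross-entropy gradient and the small max-logit term. An illustrative, stripped-down version of the step outside Lightning; model, x, y, and optimizer are assumed to exist as in the Trainer above:

optimizer.zero_grad()
loss = model(x, y)   # cross-entropy loss, already wrapped by L2Wrap inside forward()
loss.backward()      # CE gradient plus the small push-toward-zero term on the logits
optimizer.step()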
