@@ -25,21 +25,6 @@ else:
     torch.backends.cudnn.allow_tf32 = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
-class L2Wrap(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, loss, y):
-        ctx.save_for_backward(y)
-        return loss
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        y = ctx.saved_tensors[0]
-        # to encourage the logits to be close to 0
-        factor = 1e-4 / (y.shape[0] * y.shape[1])
-        maxx, ids = torch.max(y, -1, keepdim=True)
-        gy = torch.zeros_like(y)
-        gy.scatter_(-1, ids, maxx * factor)
-        return (grad_output, gy)
 class TrainerConfig:
     batch_size = 64
     learning_rate = 4e-4
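
Note: L2Wrap is only removed from this file, not from the project; after this change the model is presumably the one that returns the loss already wrapped (see the run_epoch hunk below), so the trainer no longer needs the class. For reference, a self-contained sketch of what the wrapper does, with an illustrative check of the extra gradient it injects (the shapes and the stand-in loss are made up for the example):

    import torch

    class L2Wrap(torch.autograd.Function):
        # Same class as the one removed above: forward passes the loss through
        # unchanged; backward adds a tiny gradient that pushes the largest logit
        # at each position toward 0, scaled by 1e-4 / (batch * seq_len).
        @staticmethod
        def forward(ctx, loss, y):
            ctx.save_for_backward(y)
            return loss

        @staticmethod
        def backward(ctx, grad_output):
            y = ctx.saved_tensors[0]
            factor = 1e-4 / (y.shape[0] * y.shape[1])
            maxx, ids = torch.max(y, -1, keepdim=True)   # largest logit per position
            gy = torch.zeros_like(y)
            gy.scatter_(-1, ids, maxx * factor)          # nonzero gradient only at that logit
            return (grad_output, gy)

    logits = torch.randn(2, 3, 5, requires_grad=True)    # (batch, seq, vocab), illustrative
    loss = logits.logsumexp(-1).mean()                   # stand-in for the real loss
    L2Wrap.apply(loss, logits).backward()
    # logits.grad now holds the usual loss gradient plus maxx * 1e-4 / (2 * 3)
    # at the argmax entry of each position.
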
@@ -74,14 +59,13 @@ class Trainer(LightningLite):
         model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type,
                     n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd))
         print('[1]')
-        model.to(self.device)
-        print('[2]')
         with torch.no_grad():
             if m_cfg.LOAD_MODEL:
                 print('loading', m_cfg.MODEL_NAME)
-                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location=torch.device(self.device))
+                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location='cpu')
                 model.load_state_dict(m2)
                 del m2
+        model.to(self.device)
 
         self.model = model
         self.train_dataset = train_dataset
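
Note: the load path above now keeps the checkpoint on the CPU (map_location='cpu') and only calls model.to(self.device) after load_state_dict, so the checkpoint tensors and the device copy of the weights never have to coexist in GPU memory. A minimal sketch of that pattern, with an illustrative helper name that is not part of this repo:

    import torch

    def load_cpu_then_move(model, path, device):
        # Hypothetical helper illustrating the pattern used above.
        state = torch.load(path, map_location='cpu')  # checkpoint stays in host RAM
        model.load_state_dict(state)
        del state                                     # drop the CPU copy before the move
        return model.to(device)
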
@@ -106,8 +90,6 @@ class Trainer(LightningLite):
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)
         model, optimizer = self.setup(model, optimizer)
-        gc.collect()
-        torch.cuda.empty_cache()
         print('[3]')
 
         def run_epoch(split):
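
Note: the one-off gc.collect() / torch.cuda.empty_cache() after optimizer setup is dropped here; the next hunk re-adds the same two calls inside run_epoch, so unreferenced Python objects and cached allocator blocks are released at the start of every epoch rather than once at startup. Roughly the intended placement (a sketch, not code from this diff):

    import gc
    import torch

    def run_epoch_outline(loader):
        # Sketch only: clean up right after the loader is ready, once per epoch.
        gc.collect()                 # free unreferenced Python objects
        torch.cuda.empty_cache()     # return cached CUDA blocks to the driver
        for x, y in loader:
            pass                     # forward/backward as in the hunk below
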
@@ -129,11 +111,12 @@ class Trainer(LightningLite):
             pbar = tqdm(enumerate(loader), total=len(
                 loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
             loader = self.setup_dataloaders(loader)
+            gc.collect()
+            torch.cuda.empty_cache()
 
             for it, (x, y) in pbar:
                 with torch.set_grad_enabled(is_train):
-                    yyy, loss = model(x, y) # forward the model
-                    lossL2 = L2Wrap.apply(loss, yyy)
+                    loss = model(x, y) # forward the model
 
                 if os.environ['RWKV_DEEPSPEED'] == '0':
                     all_loss = [loss.clone()]
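
Note: run_epoch used to unpack the logits and apply the wrapper itself (yyy, lossL2); after this change model(x, y) is expected to return a single loss with L2Wrap already applied inside the model. A sketch of that assumption (TinyLM and its layers are illustrative, and L2Wrap is the class shown in the first hunk, assumed to be defined alongside the model):

    import torch
    import torch.nn.functional as F
    # assumes L2Wrap (the autograd.Function from the first hunk) is in scope here

    class TinyLM(torch.nn.Module):
        # Illustrative stand-in for the real GPT: forward computes the loss itself
        # and returns it already wrapped, so the trainer only sees one scalar.
        def __init__(self, vocab_size=16, n_embd=8):
            super().__init__()
            self.emb = torch.nn.Embedding(vocab_size, n_embd)
            self.head = torch.nn.Linear(n_embd, vocab_size)

        def forward(self, idx, targets):
            logits = self.head(self.emb(idx))                   # (B, T, vocab)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return L2Wrap.apply(loss, logits)                   # wrapped inside the model
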
@@ -143,7 +126,7 @@ class Trainer(LightningLite):
 
                 if is_train: # backprop and update the parameters
                     model.zero_grad()
-                    self.backward(lossL2)
+                    self.backward(loss)
 
                     # deepspeed will handle gradient_clipping
 
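
Note: the "deepspeed will handle gradient_clipping" comment refers to DeepSpeed applying its configured gradient clipping inside the engine, so the trainer only needs to clip manually when RWKV_DEEPSPEED is off. An illustrative sketch of that split (the helper name and the grad_norm_clip default are assumptions, not taken from this diff):

    import os
    import torch

    def maybe_clip(model, grad_norm_clip=1.0):
        # Hypothetical helper: clip manually only when DeepSpeed is not in charge.
        if os.environ.get('RWKV_DEEPSPEED', '0') == '0':
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
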