improvements

main
BlinkDL 3 years ago
parent f81349f127
commit 778c0a7f58

@@ -2,12 +2,13 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
-import json
+import json, math
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 from pytorch_lightning.utilities import rank_zero_info
 from .binidx import MMapIndexedDataset
+from .utils import MaybeIsPrime
 class MyDataset(Dataset):
@@ -20,6 +21,18 @@ class MyDataset(Dataset):
             print("current vocab size =", self.vocab_size, "(make sure it's correct)")
             self.data_size = len(self.data._bin_buffer) // 2
             print(f"data has {self.data_size} tokens.")
+            if args.my_pile_mode > 0:
+                assert self.data_size == 332115325534 and self.vocab_size == 50277 and args.ctx_len == 1024
+                self.samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
+                assert self.samples_per_epoch == 40320
+                print(f"########## Pile 20b-tokenized mode {args.my_pile_mode} ##########")
+                self.magic_prime = 324331313
+                dataset_slot = self.data_size // args.ctx_len
+                assert MaybeIsPrime(self.magic_prime)
+                assert self.magic_prime % 3 == 2
+                assert self.magic_prime / dataset_slot > 0.999999 and self.magic_prime / dataset_slot <= 1
         elif args.data_type == "numpy":
             self.data = np.load(args.data_file).astype("int")
             self.vocab_size = args.vocab_size
@@ -48,15 +61,29 @@ class MyDataset(Dataset):
             self.itos = {i: ch for i, ch in enumerate(unique)}
     def __len__(self):
-        return self.args.epoch_steps * int(self.args.devices) * self.args.micro_bsz
+        return self.args.epoch_steps * self.args.micro_bsz
     def __getitem__(self, idx):
         #
         # we are cheating: pick a random spot in dataset
         #
+        rank = self.global_rank
+        epoch = self.real_epoch
+        world_size = self.world_size
+        # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
         ctx_len = self.args.ctx_len
         req_len = ctx_len + 1
-        i = np.random.randint(0, self.data_size - req_len)
+        if self.args.my_pile_mode > 0:
+            ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
+            factor = (math.sqrt(5) - 1) / 2
+            factor = int(self.magic_prime * factor)
+            i = ((factor * ii * ii * ii) % self.magic_prime) * ctx_len
+            # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
+        else:
+            i = np.random.randint(0, self.data_size - req_len)
         if "MMapIndexedDataset" in str(type(self.data)):
             dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
         elif "numpy" in str(type(self.data)):

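A note on the sampling change above (not part of the commit): in Pile mode, __getitem__ stops picking random offsets and instead walks a fixed permutation of ctx_len-sized slots. Since magic_prime is a prime with magic_prime % 3 == 2, the map ii -> (factor * ii**3) mod magic_prime is a bijection on [0, magic_prime), so each slot is visited exactly once per full cycle and the order is reproducible across restarts. A minimal sketch with a toy prime (the constants here are illustrative, not the ones asserted above):

import math

# Toy stand-in for magic_prime = 324331313; any prime p with p % 3 == 2 works.
magic_prime = 17
factor = int(magic_prime * (math.sqrt(5) - 1) / 2)  # golden-ratio spread, as in __getitem__

slots = [(factor * ii * ii * ii) % magic_prime for ii in range(magic_prime)]
assert len(set(slots)) == magic_prime  # bijection: every dataset slot hit exactly once
print(slots)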
@@ -2,7 +2,7 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
-import os, math, gc, time
+import os, math, gc
 from re import L
 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@ def __nop(ob):
 MyModule = nn.Module
 MyFunction = __nop
-if os.environ["RWKV_JIT"] == "1":
+if os.environ["RWKV_JIT_ON"] == "1":
     MyModule = torch.jit.ScriptModule
     MyFunction = torch.jit.script_method
@@ -273,9 +273,31 @@ class RWKV(pl.LightningModule):
             self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
     def configure_optimizers(self):
+        lr_1x = set()
+        lr_2x = set()
+        lr_3x = set()
+        for n, p in self.named_parameters():
+            if ("time_mix" in n) and (self.args.my_pile_mode == 2):
+                lr_2x.add(n)
+            elif "time_decay" in n:
+                lr_2x.add(n)
+            elif "time_first" in n:
+                lr_3x.add(n)
+            else:
+                lr_1x.add(n)
+        lr_1x = sorted(list(lr_1x))
+        lr_2x = sorted(list(lr_2x))
+        lr_3x = sorted(list(lr_3x))
+        # print('1x', lr_1x)
+        # print('2x', lr_2x)
+        # print('3x', lr_3x)
+        param_dict = {n: p for n, p in self.named_parameters()}
         optim_groups = [
-            {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
+            {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+            {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
+            {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
         ]
         if self.deepspeed_offload:
             return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
         return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
@@ -326,25 +348,13 @@ class RWKV(pl.LightningModule):
         idx, targets = batch
         logits = self(idx)
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
-        if self.trainer.global_rank == 0:
-            t_now = time.time_ns()
-            try:
-                t_cost = (t_now - self.trainer.my_time_ns) / 1e9
-                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
-                self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
-            except:
-                pass
-            self.trainer.my_time_ns = t_now
-            self.trainer.my_loss = loss.item()
-            self.trainer.my_loss_sum += self.trainer.my_loss
-            self.trainer.my_loss_count += 1
-            self.trainer.my_epoch_loss = self.trainer.my_loss_sum / self.trainer.my_loss_count
-            self.log("lr", self.trainer.my_lr, prog_bar=True, on_step=True)
-            self.log("loss", self.trainer.my_epoch_loss, prog_bar=True, on_step=True)
         return L2Wrap.apply(loss, logits)
+    def training_step_end(self, batch_parts):
+        all = self.all_gather(batch_parts)
+        if self.trainer.is_global_zero:
+            self.trainer.my_loss_all = all
     def generate_init_weight(self):
         print(
             f"""

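For context (not part of the commit): configure_optimizers now splits parameters into groups carrying a custom "my_lr_scale" key (1x/2x/3x), and the train.py callback further down multiplies the scheduled base LR by that scale on every step; training_step_end gathers per-rank losses so rank 0 can log the mean. A minimal sketch of the param-group mechanism with a plain torch.optim.Adam and made-up parameter names:

import torch
import torch.nn as nn

# Hypothetical stand-ins for RWKV's time_decay (2x LR) and ordinary weights (1x LR).
params = nn.ParameterDict({
    "time_decay": nn.Parameter(torch.zeros(8)),
    "ffn_weight": nn.Parameter(torch.zeros(8, 8)),
})

optim_groups = [
    {"params": [params["ffn_weight"]], "weight_decay": 0.0, "my_lr_scale": 1.0},
    {"params": [params["time_decay"]], "weight_decay": 0.0, "my_lr_scale": 2.0},
]
optimizer = torch.optim.Adam(optim_groups, lr=1e-4)  # extra keys are kept on each param_group

base_lr = 6e-4  # whatever the warmup/decay schedule produced this step
for param_group in optimizer.param_groups:
    param_group["lr"] = base_lr * param_group["my_lr_scale"]

print([g["lr"] for g in optimizer.param_groups])  # [0.0006, 0.0012]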
@@ -0,0 +1,50 @@
+import random
+
+def MaybeIsPrime(number):
+    if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
+        return True
+    else:
+        return False
+
+def FermatPrimalityTest(number):
+    if number > 1:
+        for time in range(3):
+            randomNumber = random.randint(2, number) - 1
+            if pow(randomNumber, number - 1, number) != 1:
+                return False
+        return True
+    else:
+        return False
+
+def MillerRabinPrimalityTest(number):
+    if number == 2:
+        return True
+    elif number == 1 or number % 2 == 0:
+        return False
+    oddPartOfNumber = number - 1
+    timesTwoDividNumber = 0
+    while oddPartOfNumber % 2 == 0:
+        oddPartOfNumber = oddPartOfNumber // 2
+        timesTwoDividNumber = timesTwoDividNumber + 1
+    for time in range(3):
+        while True:
+            randomNumber = random.randint(2, number) - 1
+            if randomNumber != 0 and randomNumber != 1:
+                break
+        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
+        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
+            iterationNumber = 1
+            while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
+                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
+                iterationNumber = iterationNumber + 1
+            if randomNumberWithPower != (number - 1):
+                return False
+    return True

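For context (not part of the commit): MaybeIsPrime is a probabilistic check (3 Fermat rounds plus 3 Miller-Rabin rounds), hence the hedged name; in this commit it only sanity-checks the hard-coded magic_prime. A small usage sketch, assuming the new file lands at src/utils.py as dataset.py's relative import suggests:

from src.utils import MaybeIsPrime

magic_prime = 324331313              # constant asserted in MyDataset's Pile mode
print(MaybeIsPrime(magic_prime))     # expected True (probabilistic, not a proof)
print(magic_prime % 3 == 2)          # expected True, required by the cubic permutation
print(MaybeIsPrime(magic_prime + 1)) # even, so False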
@@ -4,10 +4,11 @@
 if __name__ == "__main__":
     print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
-    import os, warnings, math, datetime, sys
+    import os, warnings, math, datetime, sys, time
     import numpy as np
     from argparse import ArgumentParser
     import torch
+    from torch.utils.data import DataLoader
     import deepspeed
     import pytorch_lightning as pl
     from pytorch_lightning import Trainer
@ -83,10 +84,12 @@ if __name__ == "__main__":
parser.add_argument("--adam_eps", default=1e-8, type=float) parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_mode", default=0, type=int) # my special pile mode
args = parser.parse_args() args = parser.parse_args()
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False args.logger = False
args.gradient_clip_val = 1.0 args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0 args.num_sanity_val_steps = 0
@ -95,6 +98,12 @@ if __name__ == "__main__":
args.max_epochs = -1 # continue forever args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2) args.betas = (args.beta1, args.beta2)
if args.my_pile_mode > 0:
args.epoch_steps = 40320 // (int(args.devices) * args.micro_bsz)
assert args.epoch_steps * int(args.devices) * args.micro_bsz == 40320
if args.my_pile_mode == 2:
assert args.lr_final == args.lr_init
samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info( rank_zero_info(
@ -138,9 +147,9 @@ if __name__ == "__main__":
if args.precision == "fp16": if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT"] = "1" os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy: if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT"] = "0" os.environ["RWKV_JIT_ON"] = "0"
import torch import torch
@ -170,9 +179,37 @@ if __name__ == "__main__":
args = self.args args = self.args
g_step = trainer.global_step g_step = trainer.global_step
# logging # LR schedule
if trainer.global_rank == 0: w_step = args.warmup_steps
if g_step == 0: if g_step < w_step:
lr = args.lr_init * (g_step / w_step)
else:
if args.lr_final == args.lr_init:
lr = args.lr_init
else:
progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
for param_group in trainer.optimizers[0].param_groups:
if self.args.my_pile_mode == 0:
param_group["lr"] = lr * param_group["my_lr_scale"]
elif self.args.my_pile_mode == 2:
if param_group["my_lr_scale"] > 1:
param_group["lr"] = lr * 5
else:
param_group["lr"] = lr
# print(param_group["lr"], param_group["my_lr_scale"])
trainer.my_lr = lr
# rank_zero_info(f"{g_step} {lr}")
if g_step == 0:
if trainer.is_global_zero: # logging
trainer.my_loss_sum = 0 trainer.my_loss_sum = 0
trainer.my_loss_count = 0 trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
@ -196,39 +233,42 @@ if __name__ == "__main__":
) )
trainer.my_wandb = wandb trainer.my_wandb = wandb
# LR schedule
w_step = args.warmup_steps
if g_step < w_step:
lr = args.lr_init * (g_step / w_step)
else:
progress = (g_step - w_step) / (args.epoch_count * args.epoch_steps - w_step - 1)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
for param_group in trainer.optimizers[0].param_groups:
param_group["lr"] = lr
trainer.my_lr = lr
# rank_zero_info(f"{g_step} {lr}")
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
args = self.args args = self.args
# logging if trainer.is_global_zero: # logging
if trainer.global_rank == 0: t_now = time.time_ns()
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
except:
pass
trainer.my_time_ns = t_now
trainer.my_loss = trainer.my_loss_all.float().mean().item()
trainer.my_loss_sum += trainer.my_loss
trainer.my_loss_count += 1
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
if len(args.wandb) > 0: if len(args.wandb) > 0:
trainer.my_wandb.log( trainer.my_wandb.log(
{"loss": trainer.my_loss, "lr": trainer.my_lr}, {"loss": trainer.my_loss, "lr": trainer.my_lr},
step=trainer.global_step, step=trainer.global_step,
) )
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "MyDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
args = self.args args = self.args
if trainer.global_rank == 0: if trainer.is_global_zero: # logging & save state_dict
if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.epoch_count - 1: if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
torch.save( torch.save(
pl_module.state_dict(), pl_module.state_dict(),
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
@ -251,7 +291,6 @@ if __name__ == "__main__":
######################################################################################################## ########################################################################################################
from torch.utils.data import DataLoader
from src.dataset import MyDataset from src.dataset import MyDataset
from src.model import RWKV from src.model import RWKV
@ -261,8 +300,8 @@ if __name__ == "__main__":
model = RWKV(args) model = RWKV(args)
if len(args.load_model) == 0: if len(args.load_model) == 0:
args.load_model = f"{args.proj_dir}/rwkv-init.pth" # init weights to tmp file args.load_model = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, args.load_model) generate_init_weight(model, args.load_model) # save initial weights to tmp file
print(f"########## Loading {args.load_model}... ##########") print(f"########## Loading {args.load_model}... ##########")
load_dict = torch.load(args.load_model, map_location="cpu") load_dict = torch.load(args.load_model, map_location="cpu")
@ -273,5 +312,7 @@ if __name__ == "__main__":
callbacks=[train_callback(args)], callbacks=[train_callback(args)],
) )
train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=1) # must set shuffle=False, persistent_workers=False (because worker is in another thread)
trainer.fit(model, train_loader) data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(model, data_loader)

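For context (not part of the commit): the LR logic moved from a post-logging block into the top of on_train_batch_start, and it now reads: linear warmup over warmup_steps, then (unless lr_final == lr_init) exponential decay from lr_init to lr_final over epoch_count * epoch_steps steps. A standalone sketch of that schedule; the function name and sample numbers are illustrative:

import math

def scheduled_lr(g_step, lr_init, lr_final, warmup_steps, total_steps):
    # Mirrors the branches in on_train_batch_start above.
    if g_step < warmup_steps:
        return lr_init * (g_step / warmup_steps)
    if lr_final == lr_init:
        return lr_init
    progress = (g_step - warmup_steps) / (total_steps - warmup_steps - 1)
    progress = min(1, max(0, progress))
    if lr_final == 0 or lr_init == 0:  # linear decay
        return lr_init + (lr_final - lr_init) * progress
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)  # exp decay

for step in (0, 25, 50, 5000, 9999):  # e.g. 50 warmup steps, 10000 total steps
    print(step, round(scheduled_lr(step, 6e-4, 1e-5, 50, 10000), 7))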