finetuning 1.5B model using 16G VRAM

main
BlinkDL 3 years ago
parent 23b0c74950
commit 6ab2e71c25

@@ -2,7 +2,7 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
-import os, math, gc
+import os, math, gc, time
 from re import L
 import torch
 import torch.nn as nn
@@ -265,7 +265,7 @@ class RWKV(pl.LightningModule):
             {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
         ]
         if self.deepspeed_offload:
-            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
         return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
 
     @property
@@ -310,12 +310,24 @@ class RWKV(pl.LightningModule):
         return x
 
     def training_step(self, batch, batch_idx):
+        args = self.args
         idx, targets = batch
         logits = self(idx)
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+        if self.trainer.global_rank == 0:
+            t_now = time.time_ns()
+            try:
+                t_cost = (t_now - self.trainer.my_time_ns) / 1e9
+                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
+                self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
+            except:
+                pass
+            self.trainer.my_time_ns = t_now
+
         self.trainer.my_loss = loss.item()
-        self.trainer.my_epoch_loss = loss.item()
+        self.trainer.my_loss_sum += self.trainer.my_loss
+        self.trainer.my_loss_count += 1
+        self.trainer.my_epoch_loss = self.trainer.my_loss_sum / self.trainer.my_loss_count
         self.log("lr", self.trainer.my_lr, prog_bar=True, on_step=True)
         self.log("loss", self.trainer.my_epoch_loss, prog_bar=True, on_step=True)

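The per-step logging added to training_step above reduces to a step timer plus a running mean of the loss. Below is a minimal standalone sketch of the same arithmetic in plain Python; it is not part of the commit, and the helper name and example values (taken from the 16G-VRAM recipe further down) are illustrative only.

import time

# example values borrowed from the single-GPU 16G-VRAM recipe added further down
ctx_len, devices, micro_bsz = 1024, 1, 11

loss_sum, loss_count, last_ns = 0.0, 0, None

def log_step(loss_value):
    """Update the running "epoch" loss and report per-step throughput."""
    global loss_sum, loss_count, last_ns
    now_ns = time.time_ns()
    if last_ns is not None:
        t_cost = (now_ns - last_ns) / 1e9  # seconds spent on the step that just finished
        print(f"{1.0 / t_cost:.2f} it/s, {ctx_len * devices * micro_bsz / t_cost:.0f} token/s")
    last_ns = now_ns
    loss_sum += loss_value
    loss_count += 1
    return loss_sum / loss_count  # running mean, reported as the epoch loss

If a step took, say, 2 s, this setup would move 1024 × 11 = 11,264 tokens per step, i.e. about 5,632 token/s.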
@@ -3,8 +3,8 @@
 ########################################################################################################
 
 if __name__ == "__main__":
-    print("\n!!! NOTE: THIS IS STILL WIP (and a bit slower than RWKV-4) !!!\n")
-    import os, warnings, math, datetime
+    print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
+    import os, warnings, math, datetime, sys
     import numpy as np
     from argparse import ArgumentParser
     import torch
@@ -16,31 +16,60 @@ if __name__ == "__main__":
     from pytorch_lightning.callbacks import TQDMProgressBar
     from pytorch_lightning import Callback
 
-    seed_everything(42)
+    # print("WARNING: THIS IS ONLY FOR DEBUG")
+    # seed_everything(42)
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
     warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
     warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
 
     ########################################################################################################
+    #
+    # example: train a simple L6-D512 RWKV from scratch
+    #
+    # python train.py --load_model "" --wandb "" --proj_dir "out" \
+    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
+    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
+    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
+    #
+    # example: fine-tune RWKV 1.5B using 8xA100 40G
+    #
+    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
+    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
+    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
+    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
+    #
+    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
+    #
+    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
+    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
+    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
+    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
 
     parser = ArgumentParser()
     parser = Trainer.add_argparse_args(parser)
     parser.add_argument("--load_model", default="", type=str)
-    parser.add_argument("--wandb", default="", type=str)
+    parser.add_argument("--wandb", default="", type=str)  # wandb project name
     parser.add_argument("--proj_dir", default="out", type=str)
     parser.add_argument("--data_file", default="", type=str)
     parser.add_argument("--data_type", default="utf-8", type=str)
-    parser.add_argument("--vocab_size", default=0, type=int)
+    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
     parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int)
+    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has xxx steps
     parser.add_argument("--epoch_count", default=500, type=int)
     parser.add_argument("--epoch_begin", default=0, type=int)
     parser.add_argument("--epoch_save", default=5, type=int)
-    parser.add_argument("--micro_bsz", default=12, type=int)
+    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--pre_ffn", default=0, type=int)
@@ -53,17 +82,15 @@ if __name__ == "__main__":
     parser.add_argument("--beta2", default=0.99, type=float)
     parser.add_argument("--adam_eps", default=1e-8, type=float)
-    parser.add_argument("--grad_cp", default=0, type=int)
-    parser.add_argument("--data_workers", default=1, type=int)
+    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
     args = parser.parse_args()
 
+    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
     args.enable_checkpointing = False
     args.logger = False
     args.gradient_clip_val = 1.0
     args.num_sanity_val_steps = 0
     args.check_val_every_n_epoch = int(1e20)
-    args.auto_select_gpus = True
     args.log_every_n_steps = int(1e20)
     args.max_epochs = -1  # continue forever
     args.betas = (args.beta1, args.beta2)
@ -74,7 +101,7 @@ if __name__ == "__main__":
f""" f"""
############################################################################ ############################################################################
# #
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''} # RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
# #
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
# #
@ -86,9 +113,9 @@ if __name__ == "__main__":
# #
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps} # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
# #
# torch {torch.__version__}, recommend 1.12.1+cu116 or newer # Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer # Found deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer # Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
# #
############################################################################ ############################################################################
""" """
@ -102,7 +129,7 @@ if __name__ == "__main__":
assert len(args.data_file) > 0 assert len(args.data_file) > 0
if args.lr_final == 0 or args.lr_init == 0: if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule.\n\n") rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"] assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision os.environ["RWKV_FLOAT_MODE"] = args.precision
@ -142,8 +169,15 @@ if __name__ == "__main__":
# logging # logging
if trainer.global_rank == 0: if trainer.global_rank == 0:
if g_step == 0: if g_step == 0:
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
trainer.my_log.write(f"NEW RUN {datetime.datetime.now()}\n{vars(self.args)}\n") trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
try:
print(f"\n{trainer.strategy.config}\n")
trainer.my_log.write(f"{trainer.strategy.config}\n")
except:
pass
trainer.my_log.flush() trainer.my_log.flush()
if len(args.wandb) > 0: if len(args.wandb) > 0:
print("Login to wandb...") print("Login to wandb...")
@ -152,7 +186,7 @@ if __name__ == "__main__":
model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd) model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
wandb.init( wandb.init(
project=args.wandb, project=args.wandb,
name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), name=model_name + "-" + args.my_timestamp,
config=args, config=args,
save_code=False, save_code=False,
) )
@ -198,6 +232,9 @@ if __name__ == "__main__":
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush() trainer.my_log.flush()
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
@rank_zero_only @rank_zero_only
def generate_init_weight(model, temp_name): def generate_init_weight(model, temp_name):
try: try:
@ -232,5 +269,5 @@ if __name__ == "__main__":
callbacks=[train_callback(args)], callbacks=[train_callback(args)],
) )
train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=args.data_workers) train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=1)
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

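For scale: the new single-GPU recipe fits the 1.5B model in 16G VRAM by combining fp16, deepspeed_stage_2_offload (which moves optimizer state to the CPU, hence the DeepSpeedCPUAdam path fixed above) and --grad_cp 1, and it shrinks a mini "epoch" to 200 steps so checkpoints come more often. A quick back-of-the-envelope on how much data each such epoch covers, using only the numbers from that example (standalone sketch, not part of the commit):

# figures from the "--precision fp16 --strategy deepspeed_stage_2_offload" example above
epoch_steps = 200   # steps per mini "epoch"
micro_bsz = 11      # sequences per step on the single GPU
ctx_len = 1024      # tokens per sequence

tokens_per_epoch = epoch_steps * micro_bsz * ctx_len
print(f"{tokens_per_epoch:,} tokens per mini-epoch")  # 2,252,800

So with --epoch_save 1 the run saves a checkpoint roughly every 2.25M training tokens.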