finetuning 1.5B model using 16G VRAM

branch: main · BlinkDL · 3 years ago
parent 23b0c74950 · commit 6ab2e71c25

@@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc
import os, math, gc, time
from re import L
import torch
import torch.nn as nn
@@ -265,7 +265,7 @@ class RWKV(pl.LightningModule):
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
@property
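
Note on the one-word change above: DeepSpeed's CPU and GPU Adam implementations spell the AdamW switch differently. `DeepSpeedCPUAdam` takes `adamw_mode`, while `FusedAdam` takes `adam_w_mode`, so the old keyword raised a TypeError whenever optimizer offload was active. A minimal sketch of the two constructors (assumes deepspeed and its compiled ops are available; the parameter values are placeholders):

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

    params = [torch.nn.Parameter(torch.zeros(8))]
    # CPU-offload path: note adamw_mode, not adam_w_mode
    cpu_adam = DeepSpeedCPUAdam(params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8,
                                bias_correction=True, adamw_mode=False,
                                weight_decay=0, amsgrad=False)
    # GPU path: FusedAdam keeps the adam_w_mode spelling
    gpu_adam = FusedAdam(params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8,
                         bias_correction=True, adam_w_mode=False,
                         weight_decay=0, amsgrad=False)
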
@@ -310,14 +310,26 @@ class RWKV(pl.LightningModule):
return x
def training_step(self, batch, batch_idx):
args = self.args
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
self.trainer.my_loss = loss.item()
self.trainer.my_epoch_loss = loss.item()
self.log("lr", self.trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", self.trainer.my_epoch_loss, prog_bar=True, on_step=True)
if self.trainer.global_rank == 0:
t_now = time.time_ns()
try:
t_cost = (t_now - self.trainer.my_time_ns) / 1e9
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("token/s", args.ctx_len * float(args.devices) * args.micro_bsz / t_cost, prog_bar=True, on_step=True)
except:
pass
self.trainer.my_time_ns = t_now
self.trainer.my_loss = loss.item()
self.trainer.my_loss_sum += self.trainer.my_loss
self.trainer.my_loss_count += 1
self.trainer.my_epoch_loss = self.trainer.my_loss_sum / self.trainer.my_loss_count
self.log("lr", self.trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", self.trainer.my_epoch_loss, prog_bar=True, on_step=True)
return L2Wrap.apply(loss, logits)
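
Note on the new logging above: rank 0 now times each training step and reports the real iteration rate and token throughput, and the progress-bar `loss` becomes a running mean over the current mini-epoch instead of the last step's value. A small illustration of the throughput formula, using the single-GPU fine-tuning example below (ctx_len 1024, 1 device, micro_bsz 11); the step time is a made-up placeholder, not a measurement:

    ctx_len, devices, micro_bsz = 1024, 1, 11
    t_cost = 2.5                                                  # hypothetical seconds per step
    print("REAL it/s:", 1.0 / t_cost)                             # 0.4
    print("token/s  :", ctx_len * devices * micro_bsz / t_cost)   # 4505.6
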

@@ -3,8 +3,8 @@
########################################################################################################
if __name__ == "__main__":
print("\n!!! NOTE: THIS IS STILL WIP (and a bit slower than RWKV-4) !!!\n")
import os, warnings, math, datetime
print("\n!!! NOTE: THIS IS STILL WIP !!!\n")
import os, warnings, math, datetime, sys
import numpy as np
from argparse import ArgumentParser
import torch
@@ -16,31 +16,60 @@ if __name__ == "__main__":
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning import Callback
seed_everything(42)
# print("WARNING: THIS IS ONLY FOR DEBUG")
# seed_everything(42)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
########################################################################################################
# example: train a simple L6-D512 RWKV from scratch
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: fine-tune RWKV 1.5B using 8xA100 40G
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
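
Rough memory budget behind the new 16G example (back-of-the-envelope ZeRO arithmetic, not measurements from this repo): in fp16 a 1.5B-parameter model needs roughly 3 GB for weights and another ~3 GB for gradients on the GPU, while `deepspeed_stage_2_offload` moves the ~18 GB of Adam state (fp32 master weights plus momentum and variance, ~12 bytes per parameter) into CPU RAM, and `--grad_cp 1` trades extra compute for much smaller activation memory:

    n_params = 1.5e9
    gib = 1024 ** 3
    print("fp16 weights (GPU):", round(2 * n_params / gib, 1), "GiB")   # ~2.8
    print("fp16 grads   (GPU):", round(2 * n_params / gib, 1), "GiB")   # ~2.8
    print("Adam state   (CPU):", round(12 * n_params / gib, 1), "GiB")  # ~16.8, offloaded
    # whatever VRAM remains goes to activations, which --grad_cp 1 keeps small
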
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
parser.add_argument("--load_model", default="", type=str)
parser.add_argument("--wandb", default="", type=str)
parser.add_argument("--wandb", default="", type=str) # wandb project name
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has xxx steps
parser.add_argument("--epoch_count", default=500, type=int)
parser.add_argument("--epoch_begin", default=0, type=int)
parser.add_argument("--epoch_save", default=5, type=int)
parser.add_argument("--micro_bsz", default=12, type=int)
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--pre_ffn", default=0, type=int)
@@ -53,17 +82,15 @@ if __name__ == "__main__":
parser.add_argument("--beta2", default=0.99, type=float)
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int)
parser.add_argument("--data_workers", default=1, type=int)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
args = parser.parse_args()
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.auto_select_gpus = True
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
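
Note on the programmatic Trainer overrides above: because the parser is built with `Trainer.add_argparse_args(parser)`, any Trainer kwarg set on the parsed namespace (checkpointing off, no validation, gradient clipping at 1.0, `max_epochs = -1` to run until killed) is picked up when the namespace is handed back to the Trainer. A hedged sketch of that round trip under the Lightning 1.7-era API; the exact construction in train.py may differ:

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])            # empty argv, for illustration only
    args.gradient_clip_val = 1.0            # programmatic overrides, as above
    args.max_epochs = -1                    # "continue forever"
    args.enable_checkpointing = False
    trainer = Trainer.from_argparse_args(args)
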
@@ -74,7 +101,7 @@ if __name__ == "__main__":
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
@@ -86,9 +113,9 @@ if __name__ == "__main__":
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
#
# torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
# pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.2 or newer
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
@@ -102,7 +129,7 @@ if __name__ == "__main__":
assert len(args.data_file) > 0
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule.\n\n")
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
@@ -142,8 +169,15 @@ if __name__ == "__main__":
# logging
if trainer.global_rank == 0:
if g_step == 0:
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
trainer.my_log.write(f"NEW RUN {datetime.datetime.now()}\n{vars(self.args)}\n")
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
try:
print(f"\n{trainer.strategy.config}\n")
trainer.my_log.write(f"{trainer.strategy.config}\n")
except:
pass
trainer.my_log.flush()
if len(args.wandb) > 0:
print("Login to wandb...")
@@ -152,7 +186,7 @@ if __name__ == "__main__":
model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
wandb.init(
project=args.wandb,
name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"),
name=model_name + "-" + args.my_timestamp,
config=args,
save_code=False,
)
@@ -198,6 +232,9 @@ if __name__ == "__main__":
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
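
Note on the counters reset above: `my_loss_sum` / `my_loss_count` accumulate over one mini-epoch (`--epoch_steps` optimizer steps), so the logged `loss` is the running mean and the train_log line also stores `exp(loss)`, i.e. per-token perplexity. A tiny illustration with made-up per-step losses:

    import math

    step_losses = [2.31, 2.28, 2.25]         # hypothetical values
    loss_sum, loss_count = sum(step_losses), len(step_losses)
    epoch_loss = loss_sum / loss_count       # what trainer.my_epoch_loss holds
    print(f"{epoch_loss:.6f} {math.exp(epoch_loss):.4f}")   # loss and perplexity
    loss_sum, loss_count = 0, 0              # reset for the next mini-epoch
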
@rank_zero_only
def generate_init_weight(model, temp_name):
try:
@@ -232,5 +269,5 @@ if __name__ == "__main__":
callbacks=[train_callback(args)],
)
train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=args.data_workers)
train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=1)
trainer.fit(model, train_loader)
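
Note on the DataLoader change above: the `--data_workers` flag is gone and `num_workers` is pinned to 1. For reference, a minimal stand-in (not the repo's actual dataset class) showing the contract `training_step` expects from this loader with `--data_type numpy`: each item is an `(idx, targets)` pair of `ctx_len` token ids with targets shifted by one. The `epoch_steps * micro_bsz` length is an assumption about how the dataset is sized so that one Lightning epoch matches `--epoch_steps`:

    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader

    class NumpyTokenDataset(Dataset):                       # hypothetical name
        def __init__(self, npy_file, ctx_len, epoch_steps, micro_bsz):
            self.data = np.load(npy_file).astype("int64")   # 1-D array of token ids
            self.ctx_len = ctx_len
            self.length = epoch_steps * micro_bsz           # one mini-epoch of samples

        def __len__(self):
            return self.length

        def __getitem__(self, _):
            i = np.random.randint(0, len(self.data) - self.ctx_len - 1)
            chunk = torch.from_numpy(self.data[i : i + self.ctx_len + 1])
            return chunk[:-1], chunk[1:]                    # idx, targets

    # train_data = NumpyTokenDataset("../data/train.npy", 1024, 200, 11)
    # train_loader = DataLoader(train_data, batch_size=11, num_workers=1)
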
