diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 45c25e4..1a95ed6 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -2,7 +2,7 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
-import json, math
+import json, math, random
 import numpy as np
 import torch
 from torch.utils.data import Dataset
@@ -37,6 +37,19 @@ class MyDataset(Dataset):
             print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
             self.data_size = len(self.data)
             print(f"Data has {self.data_size} tokens.")
+        elif args.data_type == "wds_img":
+            def identity(x):
+                return x
+            import torchvision as vision
+            import webdataset as wds
+            import torchvision.transforms as transforms
+            img_transform = transforms.Compose(
+                [transforms.CenterCrop(256)]
+            )
+            self.data = iter(wds.WebDataset(args.data_file, resampled=True).shuffle(1000).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity).with_epoch(1000000))
+            print("WebDataset loaded.")
+            self.vocab_size = -1
+            self.data_size = -1
         else:
             if args.data_type == "dummy":
                 print("Building dummy data...")
@@ -71,35 +84,38 @@ class MyDataset(Dataset):
         return self.args.epoch_steps * self.args.micro_bsz
 
     def __getitem__(self, idx):
-        #
-        # we are cheating: pick a random spot in dataset
-        #
         args = self.args
         rank = self.global_rank
         epoch = self.real_epoch
         world_size = self.world_size
         # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
 
-        ctx_len = args.ctx_len
-        req_len = ctx_len + 1
-
-        if args.my_pile_stage > 0:
-            ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
-            factor = (math.sqrt(5) - 1) / 2
-            factor = int(args.magic_prime * factor)
-            i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
-            i = i + args.my_pile_shift
-            # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
+        if args.data_type == "wds_img":
+            dd = next(self.data)  # jpg, json, txt
+            # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
+            return dd[0], dd[2]
         else:
-            i = np.random.randint(0, self.data_size - req_len)
+            ctx_len = args.ctx_len
+            req_len = ctx_len + 1
 
-        if args.data_type == "binidx":
-            dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
-        elif args.data_type == "numpy":
-            dix = self.data[i : i + req_len]
-        else:
-            dix = [self.stoi[s] for s in self.data[i : i + req_len]]
+            if args.my_pile_stage > 0:
+                ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
+                factor = (math.sqrt(5) - 1) / 2
+                factor = int(args.magic_prime * factor)
+                i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
+                i = i + args.my_pile_shift
+                # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
+            else:
+                # cheat: pick a random spot in dataset
+                i = np.random.randint(0, self.data_size - req_len)
+
+            if args.data_type == "binidx":
+                dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
+            elif args.data_type == "numpy":
+                dix = self.data[i : i + req_len]
+            else:
+                dix = [self.stoi[s] for s in self.data[i : i + req_len]]
 
-        x = torch.tensor(dix[:-1], dtype=torch.long)
-        y = torch.tensor(dix[1:], dtype=torch.long)
-        return x, y
+            x = torch.tensor(dix[:-1], dtype=torch.long)
+            y = torch.tensor(dix[1:], dtype=torch.long)
+            return x, y
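The `wds_img` branch streams (image, metadata, caption) samples out of WebDataset tar shards instead of indexing a token array, which is why `vocab_size` and `data_size` are set to the `-1` sentinel and `__getitem__` simply pulls `next(self.data)`. A minimal standalone sketch of the same pipeline, assuming a local shard path (the path and the final print are illustrative, not from the patch):

```python
import torchvision.transforms as transforms
import webdataset as wds

def identity(x):
    return x

# CenterCrop(256) mirrors the transform built in the patch.
img_transform = transforms.Compose([transforms.CenterCrop(256)])

dataset = (
    wds.WebDataset("shards/data-000000.tar", resampled=True)  # hypothetical shard
    .shuffle(1000)                    # 1000-sample shuffle buffer
    .decode("torchrgb")               # jpg -> float CHW tensor in [0, 1]
    .to_tuple("jpg", "json", "txt")   # (image, metadata, caption)
    .map_tuple(img_transform, identity, identity)
    .with_epoch(1000000)              # nominal epoch length for the resampled stream
)

img, meta, caption = next(iter(dataset))
print(img.shape)  # torch.Size([3, 256, 256]) for sources at least 256px wide
```

Note that `__getitem__` returns only `dd[0], dd[2]`, i.e. the decoded image and the raw caption; the json metadata is dropped before batching.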
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index fc2fc5f..05d97bb 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -3,7 +3,6 @@
 ########################################################################################################
 
 import os, math, gc
-from re import L
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
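The new `src/model_img.py` below defines `RWKV_IMG`, a convolutional autoencoder trained end to end with an SSIM reconstruction loss. Its bottleneck is binarized by `To2Bin`: the forward pass samples hard 0/1 values, while the backward pass is a straight-through estimator that pretends the op is the identity. A self-contained demo of just that trick (the class body matches the patch; the test tensor is made up):

```python
import torch

class To2Bin(torch.autograd.Function):
    # Stochastic binarization: for p in [0, 1], floor(p + U(0, 1)) equals 1
    # with probability p. backward() passes the gradient through unchanged.
    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + torch.empty_like(x).uniform_(0, 1))

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

z = torch.randn(5, requires_grad=True)
bits = To2Bin.apply(torch.sigmoid(z))  # each element is exactly 0.0 or 1.0
bits.sum().backward()
print(bits, z.grad)  # z.grad is d(sigmoid)/dz; the rounding is invisible to it
```

The stochastic rounding keeps the expected activation equal to the sigmoid output, so the encoder still receives a useful training signal despite the hard quantization.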
diff --git a/RWKV-v4neo/src/model_img.py b/RWKV-v4neo/src/model_img.py
new file mode 100644
index 0000000..d057bb3
--- /dev/null
+++ b/RWKV-v4neo/src/model_img.py
@@ -0,0 +1,195 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc
+import torchvision as vision
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorch_lightning.strategies import DeepSpeedStrategy
+import deepspeed
+from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from pytorch_msssim import SSIM
+
+class To2Bin(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        return torch.floor(x + torch.empty_like(x).uniform_(0, 1))
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.clone()
+
+def __nop(ob):
+    return ob
+
+MyModule = nn.Module
+MyFunction = __nop
+if os.environ["RWKV_JIT_ON"] == "1":
+    MyModule = torch.jit.ScriptModule
+    MyFunction = torch.jit.script_method
+
+class RWKV_IMG(pl.LightningModule):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+
+        self.e0b0 = nn.BatchNorm2d(12)
+        self.e0w0 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+        self.e0b1 = nn.BatchNorm2d(12)
+        self.e0w1 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+        self.e0b2 = nn.BatchNorm2d(12)
+        self.e0w2 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+        self.e0b3 = nn.BatchNorm2d(12)
+        self.e0w3 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+
+        self.e1b0 = nn.BatchNorm2d(48)
+        self.e1w0 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+        self.e1b1 = nn.BatchNorm2d(48)
+        self.e1w1 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+        self.e1b2 = nn.BatchNorm2d(48)
+        self.e1w2 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+        self.e1b3 = nn.BatchNorm2d(48)
+        self.e1w3 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+
+        self.e2b0 = nn.BatchNorm2d(192)
+        self.e2w0 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+        self.e2b1 = nn.BatchNorm2d(192)
+        self.e2w1 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+        self.e2b2 = nn.BatchNorm2d(192)
+        self.e2w2 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+        self.e2b3 = nn.BatchNorm2d(192)
+        self.e2w3 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+
+        self.ewww = nn.Conv2d(192, 8, kernel_size = 3, stride = 1, padding = 1)
+
+        self.dwww = nn.Conv2d(8, 192, kernel_size = 3, stride = 1, padding = 1)
+
+        self.d0b0 = nn.BatchNorm2d(192)
+        self.d0w0 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+        self.d0b1 = nn.BatchNorm2d(192)
+        self.d0w1 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+        self.d0b2 = nn.BatchNorm2d(192)
+        self.d0w2 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+        self.d0b3 = nn.BatchNorm2d(192)
+        self.d0w3 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
+
+        self.d1b0 = nn.BatchNorm2d(48)
+        self.d1w0 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+        self.d1b1 = nn.BatchNorm2d(48)
+        self.d1w1 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+        self.d1b2 = nn.BatchNorm2d(48)
+        self.d1w2 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+        self.d1b3 = nn.BatchNorm2d(48)
+        self.d1w3 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
+
+        self.d2b0 = nn.BatchNorm2d(12)
+        self.d2w0 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+        self.d2b1 = nn.BatchNorm2d(12)
+        self.d2w1 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+        self.d2b2 = nn.BatchNorm2d(12)
+        self.d2w2 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+        self.d2b3 = nn.BatchNorm2d(12)
+        self.d2w3 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
+
+        self.SSIM = SSIM(data_range=1, size_average=True, channel=3)
+
+    def configure_optimizers(self):
+        args = self.args
+        optim_groups = [
+            {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
+        ]
+        if self.deepspeed_offload:
+            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
+        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+
+    @property
+    def deepspeed_offload(self) -> bool:
+        strategy = self.trainer.strategy
+        if isinstance(strategy, DeepSpeedStrategy):
+            config = strategy.config["zero_optimization"]
+            return config.get("offload_optimizer") or config.get("offload_param")
+        return False
+
+    def forward(self, img):
+        x = img
+
+        x = F.pixel_unshuffle(x, 2)
+        x = x + self.e0w1(F.mish(self.e0b1(self.e0w0(F.mish(self.e0b0(x))))))
+        x = x + self.e0w3(F.mish(self.e0b3(self.e0w2(F.mish(self.e0b2(x))))))
+
+        x = F.pixel_unshuffle(x, 2)
+        x = x + self.e1w1(F.mish(self.e1b1(self.e1w0(F.mish(self.e1b0(x))))))
+        x = x + self.e1w3(F.mish(self.e1b3(self.e1w2(F.mish(self.e1b2(x))))))
+
+        x = F.pixel_unshuffle(x, 2)
+        x = x + self.e2w1(F.mish(self.e2b1(self.e2w0(F.mish(self.e2b0(x))))))
+        x = x + self.e2w3(F.mish(self.e2b3(self.e2w2(F.mish(self.e2b2(x))))))
+
+        x = self.ewww(x)
+
+        x = To2Bin.apply(torch.sigmoid(x))
+        # print(x.shape, x)
+
+        x = self.dwww(x)
+
+        x = x + self.d0w1(F.mish(self.d0b1(self.d0w0(F.mish(self.d0b0(x))))))
+        x = x + self.d0w3(F.mish(self.d0b3(self.d0w2(F.mish(self.d0b2(x))))))
+        x = F.pixel_shuffle(x, 2)
+
+        x = x + self.d1w1(F.mish(self.d1b1(self.d1w0(F.mish(self.d1b0(x))))))
+        x = x + self.d1w3(F.mish(self.d1b3(self.d1w2(F.mish(self.d1b2(x))))))
+        x = F.pixel_shuffle(x, 2)
+
+        x = x + self.d2w1(F.mish(self.d2b1(self.d2w0(F.mish(self.d2b0(x))))))
+        x = x + self.d2w3(F.mish(self.d2b3(self.d2w2(F.mish(self.d2b2(x))))))
+        x = F.pixel_shuffle(x, 2)
+
+        x = torch.sigmoid(x)
+        return x
+
+    def training_step(self, batch, batch_idx):
+        args = self.args
+        img, txt = batch
+        out = self(img)
+        if self.trainer.is_global_zero:
+            if (self.trainer.global_step+1) % (100 * int(args.devices)) == 0:
+                vision.utils.save_image(img[:4], f"test/image_model/{self.trainer.global_step}-src.jpg")
+                vision.utils.save_image(out[:4], f"test/image_model/{self.trainer.global_step}-out.jpg")
+
+        return 1 - self.SSIM(out.float(), img.float())
+
+    def training_step_end(self, batch_parts):
+        all = self.all_gather(batch_parts)
+        if self.trainer.is_global_zero:
+            self.trainer.my_loss_all = all
+
+    def generate_init_weight(self):
+        print(
+            f"""
+############################################################################
+#
+# Init model weight (slow for large models)...
+#
+############################################################################
+"""
+        )
+        m = {}
+        for n in self.state_dict():
+            p = self.state_dict()[n]
+            shape = p.shape
+
+            m[n] = p
+
+            m[n] = m[n].cpu()
+            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
+                m[n] = m[n].half()
+            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                m[n] = m[n].bfloat16()
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        return m
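A quick size check of the bottleneck, derived from the layer shapes above (the 256×256 input is assumed from the `CenterCrop(256)` in dataset.py): each `F.pixel_unshuffle(x, 2)` halves height and width while quadrupling channels (3→12→48→192), so the image reaches `ewww` as 32×32×192 and is projected to 8 binarized channels.

```python
# Back-of-envelope bottleneck size for an assumed 256x256 RGB input.
H = W = 256
latent_h = latent_w = H // 2 // 2 // 2   # three pixel_unshuffle(2) stages -> 32
latent_bits = 8 * latent_h * latent_w    # ewww outputs 8 binary channels
raw_bits = 3 * H * W * 8                 # 24-bit RGB source
print(latent_bits)                       # 8192 bits = 1 KiB per image
print(raw_bits / latent_bits)            # 192.0x smaller than raw pixels
```

The training objective `1 - self.SSIM(out, img)` (from `pytorch_msssim`) then asks the decoder to rebuild perceptual structure from that roughly 1 KiB binary code.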
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 5b31797..7daacb3 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -158,7 +158,7 @@ if __name__ == "__main__":
         if args.my_pile_stage == 2:
             args.warmup_steps = 10
         else:
-            args.warmup_steps = 50
+            args.warmup_steps = 30
         args.epoch_begin = max_p + 1
 
     samples_per_epoch = args.epoch_steps * args.real_bsz
@@ -188,7 +188,7 @@
     )
     rank_zero_info(str(vars(args)) + "\n")
 
-    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy"]
+    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img"]
 
     if args.lr_final == 0 or args.lr_init == 0:
         rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
@@ -223,12 +223,16 @@
 
     from src.trainer import train_callback, generate_init_weight
     from src.dataset import MyDataset
-    from src.model import RWKV
 
     train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size
 
-    model = RWKV(args)
+    if args.data_type == 'wds_img':
+        from src.model_img import RWKV_IMG
+        model = RWKV_IMG(args)
+    else:
+        from src.model import RWKV
+        model = RWKV(args)
 
     if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
         init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
@@ -250,6 +254,10 @@
             print(f"Trying {args.load_model}")
             load_dict = torch.load(args.load_model, map_location="cpu")
+        # load_keys = load_dict.keys()
+        # for k in model.state_dict():
+        #     if k not in load_keys:
+        #         load_dict[k] = model.state_dict()[k]
 
     model.load_state_dict(load_dict)
 
     trainer = Trainer.from_argparse_args(
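The commented-out block ahead of `model.load_state_dict(load_dict)` hints at resuming from a checkpoint that predates newly added parameters: any key missing from `load_dict` would be filled from the freshly initialized model. A sketch of that idea as a helper (the function name `load_partial` is mine, not the repo's):

```python
import torch
from torch import nn

def load_partial(model: nn.Module, ckpt_path: str) -> None:
    # Fall back to the model's own (freshly initialized) tensors for any
    # parameter the checkpoint does not contain, then load strictly.
    load_dict = torch.load(ckpt_path, map_location="cpu")
    for k, v in model.state_dict().items():
        if k not in load_dict:
            load_dict[k] = v
    model.load_state_dict(load_dict)
```

`model.load_state_dict(load_dict, strict=False)` would also skip missing keys (leaving the fresh initialization in place), at the cost of silently ignoring unexpected keys too.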