diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 1a95ed6..c42b0be 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -38,18 +38,9 @@ class MyDataset(Dataset):
             self.data_size = len(self.data)
             print(f"Data has {self.data_size} tokens.")
         elif args.data_type == "wds_img":
-            def identity(x):
-                return x
-            import torchvision as vision
-            import webdataset as wds
-            import torchvision.transforms as transforms
-            img_transform = transforms.Compose(
-                [transforms.CenterCrop(256)]
-            )
-            self.data = iter(wds.WebDataset(args.data_file, resampled=True).shuffle(1000).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity).with_epoch(1000000))
-            print("WebDataset loaded.")
             self.vocab_size = -1
             self.data_size = -1
+            self.data = None
         else:
             if args.data_type == "dummy":
                 print("Building dummy data...")
@@ -90,9 +81,29 @@ class MyDataset(Dataset):
         world_size = self.world_size
         # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
 
-        if args.data_type == "wds_img":
+        if args.data_type == "wds_img":
+            if self.data == None:
+                def identity(x):
+                    return x
+                import webdataset as wds
+                import torchvision.transforms as transforms
+                img_transform = transforms.Compose(
+                    [transforms.CenterCrop(256)]
+                )
+                self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
+                for pp in self.data_raw.pipeline:
+                    if 'Resampled' in str(pp):
+                        pp.deterministic = True
+                        def worker_seed():
+                            return rank*100000+epoch
+                        pp.worker_seed = worker_seed
+                self.data = iter(self.data_raw)
+                # print(f"WebDataset loaded for rank {rank} epoch {epoch}")
+
             dd = next(self.data) # jpg, json, txt
             # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
+            # with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
+            #     tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
             return dd[0], dd[2]
         else:
             ctx_len = args.ctx_len
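Note on the dataset change above: building the WebDataset pipeline lazily inside `__getitem__` lets each DDP rank derive its shuffle and shard-resampling seeds from `(epoch, rank)`, so every rank sees a distinct but replayable stream; the previous version built one iterator in `__init__` with an unseeded shuffle, which was not reproducible. Below is a minimal standalone sketch of the same seeding pattern. The shard pattern and the `RANK`/`EPOCH` constants are illustrative stand-ins, not part of this patch.

```python
# Sketch only: per-rank, per-epoch deterministic shuffling with webdataset,
# mirroring the patch above. RANK/EPOCH and the shard pattern are hypothetical.
import random
import webdataset as wds

RANK, EPOCH = 0, 3  # stand-ins for the values MyDataset reads from the trainer

dataset = (
    wds.WebDataset("shards-{000000..000099}.tar", resampled=True)
    .shuffle(10000, initial=1000, rng=random.Random(EPOCH * 100000 + RANK))
    .decode("torchrgb")
    .to_tuple("jpg", "json", "txt")
)

# As in the patch: make the ResampledShards stage deterministic so shard
# order is also a pure function of (rank, epoch).
for stage in dataset.pipeline:
    if "Resampled" in str(stage):
        stage.deterministic = True
        stage.worker_seed = lambda: RANK * 100000 + EPOCH

sample = next(iter(dataset))  # (jpg tensor, json dict, txt string)
```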
diff --git a/RWKV-v4neo/src/model_img.py b/RWKV-v4neo/src/model_img.py
index d057bb3..a451921 100644
--- a/RWKV-v4neo/src/model_img.py
+++ b/RWKV-v4neo/src/model_img.py
@@ -2,6 +2,138 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
+import numpy as np
+import os
+import torch
+from torchvision import models
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class L2pooling(nn.Module):
+    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
+        super(L2pooling, self).__init__()
+        self.padding = (filter_size - 2) // 2
+        self.stride = stride
+        self.channels = channels
+        a = np.hanning(filter_size)[1:-1]
+        g = torch.Tensor(a[:, None] * a[None, :])
+        g = g / torch.sum(g)
+        self.register_buffer(
+            "filter", g[None, None, :, :].repeat((self.channels, 1, 1, 1))
+        )
+
+    def forward(self, input):
+        input = input**2
+        out = F.conv2d(
+            input,
+            self.filter,
+            stride=self.stride,
+            padding=self.padding,
+            groups=input.shape[1],
+        )
+        return (out + 1e-12).sqrt()
+
+
+class DISTS(torch.nn.Module):
+    def __init__(self, load_weights=True):
+        super(DISTS, self).__init__()
+        vgg_pretrained_features = models.vgg16(
+            weights="VGG16_Weights.IMAGENET1K_V1"
+        ).features
+        self.stage1 = torch.nn.Sequential()
+        self.stage2 = torch.nn.Sequential()
+        self.stage3 = torch.nn.Sequential()
+        self.stage4 = torch.nn.Sequential()
+        self.stage5 = torch.nn.Sequential()
+        for x in range(0, 4):
+            self.stage1.add_module(str(x), vgg_pretrained_features[x])
+        self.stage2.add_module(str(4), L2pooling(channels=64))
+        for x in range(5, 9):
+            self.stage2.add_module(str(x), vgg_pretrained_features[x])
+        self.stage3.add_module(str(9), L2pooling(channels=128))
+        for x in range(10, 16):
+            self.stage3.add_module(str(x), vgg_pretrained_features[x])
+        self.stage4.add_module(str(16), L2pooling(channels=256))
+        for x in range(17, 23):
+            self.stage4.add_module(str(x), vgg_pretrained_features[x])
+        self.stage5.add_module(str(23), L2pooling(channels=512))
+        for x in range(24, 30):
+            self.stage5.add_module(str(x), vgg_pretrained_features[x])
+
+        self.register_buffer(
+            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
+        )
+        self.register_buffer(
+            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
+        )
+
+        self.chns = [3, 64, 128, 256, 512, 512]
+        self.register_buffer(
+            "alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
+        )
+        self.register_buffer("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
+        self.alpha.data.normal_(0.1, 0.01)
+        self.beta.data.normal_(0.1, 0.01)
+        weights = torch.load("test/DISTS_weights.pt")
+        self.alpha.data = weights["alpha"]
+        self.beta.data = weights["beta"]
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward_once(self, x):
+        h = (x - self.mean) / self.std
+        h = self.stage1(h)
+        h_relu1_2 = h
+        h = self.stage2(h)
+        h_relu2_2 = h
+        h = self.stage3(h)
+        h_relu3_3 = h
+        h = self.stage4(h)
+        h_relu4_3 = h
+        h = self.stage5(h)
+        h_relu5_3 = h
+        return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
+
+    def forward(self, x, y, require_grad=False, batch_average=False):
+        if require_grad:
+            feats0 = self.forward_once(x)
+            feats1 = self.forward_once(y)
+        else:
+            with torch.no_grad():
+                feats0 = self.forward_once(x)
+                feats1 = self.forward_once(y)
+        dist1 = 0
+        dist2 = 0
+        c1 = 1e-6
+        c2 = 1e-6
+        w_sum = self.alpha.sum() + self.beta.sum()
+        alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
+        beta = torch.split(self.beta / w_sum, self.chns, dim=1)
+
+        for k in range(len(self.chns)):
+            x_mean = feats0[k].mean([2, 3], keepdim=True)
+            y_mean = feats1[k].mean([2, 3], keepdim=True)
+            S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
+            dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
+
+            x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
+            y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
+            xy_cov = (feats0[k] * feats1[k]).mean(
+                [2, 3], keepdim=True
+            ) - x_mean * y_mean
+            S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
+            dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
+
+        score = 1 - (dist1 + dist2).squeeze()
+
+        if batch_average:
+            return score.mean()
+        else:
+            return score
+
+
 import os, math, gc
 import torchvision as vision
 import torch
@@ -12,155 +144,181 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
-from pytorch_msssim import SSIM
 
-class To2Bin(torch.autograd.Function):
+
+class ToBinary(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
-        return torch.floor(x + torch.empty_like(x).uniform_(0, 1))
+        return torch.floor(x + torch.empty_like(x).uniform_(0.4, 0.6))
+
     @staticmethod
     def backward(ctx, grad_output):
         return grad_output.clone()
 
-def __nop(ob):
-    return ob
-MyModule = nn.Module
-MyFunction = __nop
-if os.environ["RWKV_JIT_ON"] == "1":
-    MyModule = torch.jit.ScriptModule
-    MyFunction = torch.jit.script_method
+MyModule = torch.jit.ScriptModule
+MyFunction = torch.jit.script_method
 
-class RWKV_IMG(pl.LightningModule):
+
+class R_ENCODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
-        self.e0b0 = nn.BatchNorm2d(12)
-        self.e0w0 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-        self.e0b1 = nn.BatchNorm2d(12)
-        self.e0w1 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-        self.e0b2 = nn.BatchNorm2d(12)
-        self.e0w2 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-        self.e0b3 = nn.BatchNorm2d(12)
-        self.e0w3 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-
-        self.e1b0 = nn.BatchNorm2d(48)
-        self.e1w0 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-        self.e1b1 = nn.BatchNorm2d(48)
-        self.e1w1 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-        self.e1b2 = nn.BatchNorm2d(48)
-        self.e1w2 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-        self.e1b3 = nn.BatchNorm2d(48)
-        self.e1w3 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-
-        self.e2b0 = nn.BatchNorm2d(192)
-        self.e2w0 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-        self.e2b1 = nn.BatchNorm2d(192)
-        self.e2w1 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-        self.e2b2 = nn.BatchNorm2d(192)
-        self.e2w2 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-        self.e2b3 = nn.BatchNorm2d(192)
-        self.e2w3 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-
-        self.ewww = nn.Conv2d(192, 8, kernel_size = 3, stride = 1, padding = 1)
-
-        self.dwww = nn.Conv2d(8, 192, kernel_size = 3, stride = 1, padding = 1)
-
-        self.d0b0 = nn.BatchNorm2d(192)
-        self.d0w0 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-        self.d0b1 = nn.BatchNorm2d(192)
-        self.d0w1 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-        self.d0b2 = nn.BatchNorm2d(192)
-        self.d0w2 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-        self.d0b3 = nn.BatchNorm2d(192)
-        self.d0w3 = nn.Conv2d(192, 192, kernel_size = 3, stride = 1, padding = 1)
-
-        self.d1b0 = nn.BatchNorm2d(48)
-        self.d1w0 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-        self.d1b1 = nn.BatchNorm2d(48)
-        self.d1w1 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-        self.d1b2 = nn.BatchNorm2d(48)
-        self.d1w2 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-        self.d1b3 = nn.BatchNorm2d(48)
-        self.d1w3 = nn.Conv2d(48, 48, kernel_size = 3, stride = 1, padding = 1)
-
-        self.d2b0 = nn.BatchNorm2d(12)
-        self.d2w0 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-        self.d2b1 = nn.BatchNorm2d(12)
-        self.d2w1 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-        self.d2b2 = nn.BatchNorm2d(12)
-        self.d2w2 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-        self.d2b3 = nn.BatchNorm2d(12)
-        self.d2w3 = nn.Conv2d(12, 12, kernel_size = 3, stride = 1, padding = 1)
-
-        self.SSIM = SSIM(data_range=1, size_average=True, channel=3)
+        self.B00 = nn.BatchNorm2d(12)
+        self.C00 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
 
-    def configure_optimizers(self):
-        args = self.args
-        optim_groups = [
-            {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
-        ]
-        if self.deepspeed_offload:
-            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
-        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
-        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+        self.B10 = nn.BatchNorm2d(48)
+        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
 
-    @property
-    def deepspeed_offload(self) -> bool:
-        strategy = self.trainer.strategy
-        if isinstance(strategy, DeepSpeedStrategy):
-            config = strategy.config["zero_optimization"]
-            return config.get("offload_optimizer") or config.get("offload_param")
-        return False
+        self.B20 = nn.BatchNorm2d(192)
+        self.C20 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
 
-    def forward(self, img):
-        x = img
+        self.COUT = nn.Conv2d(192, 8, kernel_size=3, padding=1)
+
+    @MyFunction
+    def forward(self, x):
+        ACT = F.silu
         x = F.pixel_unshuffle(x, 2)
-        x = x + self.e0w1(F.mish(self.e0b1(self.e0w0(F.mish(self.e0b0(x))))))
-        x = x + self.e0w3(F.mish(self.e0b3(self.e0w2(F.mish(self.e0b2(x))))))
+        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
+        x = x + self.C03(ACT(self.C02(x)))
         x = F.pixel_unshuffle(x, 2)
-        x = x + self.e1w1(F.mish(self.e1b1(self.e1w0(F.mish(self.e1b0(x))))))
-        x = x + self.e1w3(F.mish(self.e1b3(self.e1w2(F.mish(self.e1b2(x))))))
+        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
+        x = x + self.C13(ACT(self.C12(x)))
         x = F.pixel_unshuffle(x, 2)
-        x = x + self.e2w1(F.mish(self.e2b1(self.e2w0(F.mish(self.e2b0(x))))))
-        x = x + self.e2w3(F.mish(self.e2b3(self.e2w2(F.mish(self.e2b2(x))))))
+        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
+        x = x + self.C23(ACT(self.C22(x)))
+
+        x = self.COUT(x)
+
+        return torch.sigmoid(x)
+
 
-        x = self.ewww(x)
+class R_DECODER(MyModule):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+
+        self.CIN = nn.Conv2d(8, 192, kernel_size=3, padding=1)
+
+        self.B00 = nn.BatchNorm2d(192)
+        self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
 
-        x = To2Bin.apply(torch.sigmoid(x))
-        # print(x.shape, x)
+        self.B10 = nn.BatchNorm2d(48)
+        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
 
-        x = self.dwww(x)
+        self.B20 = nn.BatchNorm2d(12)
+        self.C20 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
 
-        x = x + self.d0w1(F.mish(self.d0b1(self.d0w0(F.mish(self.d0b0(x))))))
-        x = x + self.d0w3(F.mish(self.d0b3(self.d0w2(F.mish(self.d0b2(x))))))
+    @MyFunction
+    def forward(self, x):
+        ACT = F.silu
+
+        x = self.CIN(x)
+
+        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
+        x = x + self.C03(ACT(self.C02(x)))
         x = F.pixel_shuffle(x, 2)
-        x = x + self.d1w1(F.mish(self.d1b1(self.d1w0(F.mish(self.d1b0(x))))))
-        x = x + self.d1w3(F.mish(self.d1b3(self.d1w2(F.mish(self.d1b2(x))))))
+        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
+        x = x + self.C13(ACT(self.C12(x)))
         x = F.pixel_shuffle(x, 2)
-        x = x + self.d2w1(F.mish(self.d2b1(self.d2w0(F.mish(self.d2b0(x))))))
-        x = x + self.d2w3(F.mish(self.d2b3(self.d2w2(F.mish(self.d2b2(x))))))
+        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
+        x = x + self.C23(ACT(self.C22(x)))
         x = F.pixel_shuffle(x, 2)
 
-        x = torch.sigmoid(x)
-        return x
+        return torch.sigmoid(x)
+
+
+class RWKV_IMG(pl.LightningModule):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+        self.encoder = R_ENCODER(args)
+        self.decoder = R_DECODER(args)
+        self.loss_dists = DISTS()
+
+    def configure_optimizers(self):
+        args = self.args
+        optim_groups = [
+            {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
+        ]
+        if self.deepspeed_offload:
+            return DeepSpeedCPUAdam(
+                optim_groups,
+                lr=self.args.lr_init,
+                betas=self.args.betas,
+                eps=self.args.adam_eps,
+                bias_correction=True,
+                adamw_mode=False,
+                weight_decay=0,
+                amsgrad=False,
+            )
+        return FusedAdam(
+            optim_groups,
+            lr=self.args.lr_init,
+            betas=self.args.betas,
+            eps=self.args.adam_eps,
+            bias_correction=True,
+            adam_w_mode=False,
+            weight_decay=0,
+            amsgrad=False,
+        )
+        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+
+    @property
+    def deepspeed_offload(self) -> bool:
+        strategy = self.trainer.strategy
+        if isinstance(strategy, DeepSpeedStrategy):
+            config = strategy.config["zero_optimization"]
+            return config.get("offload_optimizer") or config.get("offload_param")
+        return False
+
+    def forward(self, img):
+        z = self.encoder(img)
+        z = ToBinary.apply(z)
+        out = self.decoder(z)
+        return out
 
     def training_step(self, batch, batch_idx):
         args = self.args
         img, txt = batch
         out = self(img)
         if self.trainer.is_global_zero:
-            if (self.trainer.global_step+1) % (100 * int(args.devices)) == 0:
-                vision.utils.save_image(img[:4], f"test/image_model/{self.trainer.global_step}-src.jpg")
-                vision.utils.save_image(out[:4], f"test/image_model/{self.trainer.global_step}-out.jpg")
+            if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
+                vision.utils.save_image(
+                    img[:4], f"test/image_model/{self.trainer.global_step}-src.jpg"
+                )
+                vision.utils.save_image(
+                    out[:4], f"test/image_model/{self.trainer.global_step}-out.jpg"
+                )
 
-        return 1 - self.SSIM(out.float(), img.float())
+        loss_l1 = F.l1_loss(out, img)
+        loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
+
+        return loss_l1 + loss_dists
 
     def training_step_end(self, batch_parts):
         all = self.all_gather(batch_parts)
@@ -181,7 +339,7 @@ class RWKV_IMG(pl.LightningModule):
         for n in self.state_dict():
            p = self.state_dict()[n]
             shape = p.shape
-            
+
             m[n] = p
             m[n] = m[n].cpu()
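A note on the two new pieces composed in `training_step`: `ToBinary` is a straight-through estimator. Its forward pass stochastically rounds the encoder's sigmoid outputs to 0/1 (with `uniform_(0.4, 0.6)` noise, values below 0.4 always land on 0 and values above 0.6 always land on 1; only the middle band is stochastic), while its backward pass copies the gradient through unchanged. `DISTS` is a frozen VGG16-based perceptual distance added to the L1 term. Below is a rough round-trip sketch, assuming it runs from `RWKV-v4neo/` (the directory from which the repo's trainer imports these modules); `SimpleNamespace()` is a hypothetical stand-in for the real `args`, which the encoder and decoder only store.

```python
# Hedged sketch of the encode -> binarize -> decode path defined in this patch.
import torch
from types import SimpleNamespace
from src.model_img import R_ENCODER, R_DECODER, ToBinary

args = SimpleNamespace()  # hypothetical; the real args come from train.py
enc, dec = R_ENCODER(args), R_DECODER(args)

img = torch.rand(1, 3, 256, 256)  # a CenterCrop(256) image in [0, 1]
z = enc(img)                      # (1, 8, 32, 32) sigmoid activations
bits = ToBinary.apply(z)          # hard 0/1 code; backward is identity
recon = dec(bits)                 # (1, 3, 256, 256) reconstruction

print(bits.shape, recon.shape)    # 8 bits per 8x8 patch -> 8192 bits per image
```

Each 256×256×3 crop is therefore squeezed through a 32×32×8 binary bottleneck (8192 bits, about 1 KB), and `loss_l1 + loss_dists` replaces the previous `1 - SSIM` objective.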