From 74fedc0d8615334ff48fb85fcb882dc6513afa18 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Tue, 20 Sep 2022 06:51:15 +0000
Subject: [PATCH] CLIP-guided Binary AutoEncoder

---
 RWKV-v4neo/src/dataset.py   | 10 +++++---
 RWKV-v4neo/src/model_img.py | 44 +++++++++++++++++++++++++++-----
 RWKV-v4neo/src/trainer.py   |  4 +--
 RWKV-v4neo/train.py         | 51 +++++++++++++++++++++++++------------
 4 files changed, 80 insertions(+), 29 deletions(-)

diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index c42b0be..37a5f40 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -87,9 +87,13 @@ class MyDataset(Dataset):
                     return x
                 import webdataset as wds
                 import torchvision.transforms as transforms
-                img_transform = transforms.Compose(
-                    [transforms.CenterCrop(256)]
-                )
+                # img_transform = transforms.Compose(
+                #     [transforms.CenterCrop(256)]
+                # )
+                img_transform = transforms.Compose([
+                    transforms.CenterCrop(512),
+                    transforms.Resize((args.my_img_size))
+                ])
                 self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
                 for pp in self.data_raw.pipeline:
                     if 'Resampled' in str(pp):
diff --git a/RWKV-v4neo/src/model_img.py b/RWKV-v4neo/src/model_img.py
index a451921..b592103 100644
--- a/RWKV-v4neo/src/model_img.py
+++ b/RWKV-v4neo/src/model_img.py
@@ -8,7 +8,7 @@ import torch
 from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
-
+import clip

 class L2pooling(nn.Module):
     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
@@ -134,6 +134,8 @@ class DISTS(torch.nn.Module):
         return score


+########################################################################################################
+
 import os, math, gc
 import torchvision as vision
 import torch
@@ -144,6 +146,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+# from pytorch_msssim import MS_SSIM


 class ToBinary(torch.autograd.Function):
@@ -159,6 +162,7 @@
 MyModule = torch.jit.ScriptModule
 MyFunction = torch.jit.script_method

+########################################################################################################

 class R_ENCODER(MyModule):
     def __init__(self, args):
@@ -183,7 +187,7 @@ class R_ENCODER(MyModule):
         self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
         self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)

-        self.COUT = nn.Conv2d(192, 8, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(192, args.my_img_bit, kernel_size=3, padding=1)

     @MyFunction
     def forward(self, x):
@@ -205,13 +209,14 @@ class R_ENCODER(MyModule):
         return torch.sigmoid(x)


+########################################################################################################

 class R_DECODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args

-        self.CIN = nn.Conv2d(8, 192, kernel_size=3, padding=1)
+        self.CIN = nn.Conv2d(args.my_img_bit, 192, kernel_size=3, padding=1)

         self.B00 = nn.BatchNorm2d(192)
         self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
@@ -251,14 +256,31 @@ class R_DECODER(MyModule):
         return torch.sigmoid(x)


+########################################################################################################

 class RWKV_IMG(pl.LightningModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
+
         self.encoder = R_ENCODER(args)
         self.decoder = R_DECODER(args)
+
+        clip_name = args.my_img_clip
+        if clip_name == 'B32':
+            clip_name = 'ViT-B/32'
+        elif clip_name == 'B16':
+            clip_name = 'ViT-B/16'
+        elif clip_name == 'L14':
+            clip_name = 'ViT-L/14'
+        self.clip_model, self.clip_preprocess = clip.load(clip_name, jit = True)
+
+        for n, p in self.named_parameters():
+            if 'clip_model' in n:
+                p.requires_grad = False
+
         self.loss_dists = DISTS()
+        # self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)

     def configure_optimizers(self):
         args = self.args
@@ -308,17 +330,25 @@ class RWKV_IMG(pl.LightningModule):
         out = self(img)
         if self.trainer.is_global_zero:
             if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
+                img_dir = f"test/image_model/{args.run_name}"
+                if not os.path.exists(img_dir):
+                    os.makedirs(img_dir)
                 vision.utils.save_image(
-                    img[:4], f"test/image_model/{self.trainer.global_step}-src.jpg"
+                    img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
                 )
                 vision.utils.save_image(
-                    out[:4], f"test/image_model/{self.trainer.global_step}-out.jpg"
+                    out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
                 )

-        loss_l1 = F.l1_loss(out, img)
+        # loss_l1 = F.l1_loss(out, img)
+        # loss_ssim = 1 - self.loss_ssim(out, img)
+        # return loss_dists# * 1# + loss_l1 * 1
+        # + loss_ssim * 0.4
         loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
-        return loss_l1 + loss_dists
+        loss_clip = F.mse_loss(self.clip_model.encode_image(img), self.clip_model.encode_image(out))
+
+        return loss_dists + loss_clip * args.my_img_clip_scale


     def training_step_end(self, batch_parts):
         all = self.all_gather(batch_parts)
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index 0648c6c..aed2c83 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -61,11 +61,9 @@ class train_callback(pl.Callback):
                 if len(args.wandb) > 0:
                     print("Login to wandb...")
                     import wandb
-
-                    model_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
                     wandb.init(
                         project=args.wandb,
-                        name=model_name + " " + args.my_timestamp,
+                        name=args.run_name + " " + args.my_timestamp,
                         config=args,
                         save_code=False,
                     )
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 7daacb3..4423c33 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -3,24 +3,10 @@
 ########################################################################################################

 if __name__ == "__main__":
-    print("########## work in progress ##########")
-    import os, warnings, math, datetime, sys, time
-    import numpy as np
     from argparse import ArgumentParser
-    import torch
-    from torch.utils.data import DataLoader
-    import deepspeed
-    import pytorch_lightning as pl
     from pytorch_lightning import Trainer
-    from pytorch_lightning import seed_everything
-    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-    # print("WARNING: THIS IS ONLY FOR DEBUG")
-    # seed_everything(42)
-
-    np.set_printoptions(precision=4, suppress=True, linewidth=200)
-    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
-    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
+    print("########## work in progress ##########")

     ########################################################################################################
     #
@@ -61,11 +47,11 @@ if __name__ == "__main__":
     # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

     parser = ArgumentParser()
-    parser = Trainer.add_argparse_args(parser)

     parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
     parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
     parser.add_argument("--proj_dir", default="out", type=str)
+    parser.add_argument("--random_seed", default="-1", type=int)

     parser.add_argument("--data_file", default="", type=str)
     parser.add_argument("--data_type", default="utf-8", type=str)
@@ -98,7 +84,35 @@ if __name__ == "__main__":
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

+    parser.add_argument("--my_img_version", default=0, type=int)
+    parser.add_argument("--my_img_size", default=0, type=int)
+    parser.add_argument("--my_img_bit", default=0, type=int)
+    parser.add_argument("--my_img_clip", default='x', type=str)
+    parser.add_argument("--my_img_clip_scale", default=1, type=float)
+
+    parser = Trainer.add_argparse_args(parser)
     args = parser.parse_args()
+
+    ########################################################################################################
+
+    import os, warnings, math, datetime, sys, time
+    import numpy as np
+    import torch
+    from torch.utils.data import DataLoader
+    import deepspeed
+    import pytorch_lightning as pl
+    from pytorch_lightning import seed_everything
+    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+
+    if args.random_seed >= 0:
+        print(f"########## WARNING: GLOBAL SEED SET TO {args.random_seed} ##########\n" * 3)
+        seed_everything(args.random_seed)
+
+    np.set_printoptions(precision=4, suppress=True, linewidth=200)
+    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
+    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
+    # os.environ["WDS_SHOW_SEED"] = "1"
+
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
     args.enable_checkpointing = False
     args.replace_sampler_ddp = False
@@ -112,6 +126,11 @@ if __name__ == "__main__":
     args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
     os.environ["RWKV_T_MAX"] = str(args.ctx_len)

+    if args.data_type == "wds_img":
+        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
+        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
+    else:
+        args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
     if not os.path.exists(args.proj_dir):
         os.makedirs(args.proj_dir)
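Usage note: the new --my_img_* arguments are only consumed on the wds_img path. In the style of the example commands already kept as comments in train.py, a run could be launched along these lines; the shard path and every value below are illustrative, not taken from this patch:

    python train.py --proj_dir out --data_file "/data/img-{000000..000999}.tar" --data_type wds_img \
        --my_img_version 0 --my_img_size 224 --my_img_bit 8 --my_img_clip B32 --my_img_clip_scale 1 \
        --micro_bsz 8 --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload

With --data_type wds_img, run_name becomes e.g. v0-224-8bit-B32x1.0 and is appended to proj_dir, so repeated runs with different bit widths or CLIP backbones land in separate output directories.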
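For orientation, here is a minimal sketch of how the pieces in this patch compose end to end. The body of ToBinary is elided in the hunks above, and the encode -> binarize -> decode wiring of RWKV_IMG.forward is not part of the diff, so both are assumptions here (a common straight-through estimator is used for the binarization):

    import torch
    import torch.nn.functional as F

    class ToBinary(torch.autograd.Function):
        # assumed implementation; the class body is not shown in this patch
        @staticmethod
        def forward(ctx, x):
            # the encoder ends in sigmoid, so x is in [0, 1]; round to hard {0, 1} codes
            return torch.floor(x + 0.5)

        @staticmethod
        def backward(ctx, grad_output):
            # straight-through: pass gradients through the rounding step unchanged
            return grad_output.clone()

    def image_loss(model, img, clip_scale):
        # model is an RWKV_IMG instance; assumed order: encode, binarize, decode
        z = ToBinary.apply(model.encoder(img))  # args.my_img_bit binary channels
        out = model.decoder(z)
        # perceptual reconstruction term, same call as in training_step above
        loss_dists = model.loss_dists(out, img, require_grad=True, batch_average=True)
        # CLIP guidance: match the frozen CLIP image embeddings of input and output
        loss_clip = F.mse_loss(model.clip_model.encode_image(img),
                               model.clip_model.encode_image(out))
        return out, loss_dists + loss_clip * clip_scale

The CLIP term is what makes the binary autoencoder "CLIP-guided": reconstructions must preserve the frozen CLIP embedding of the source image, weighted by --my_img_clip_scale, on top of the DISTS perceptual loss.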