CLIP-guided Binary AutoEncoder

Branch: main
Committed by BlinkDL, 3 years ago
Parent: bb59fffac1
Commit: 74fedc0d86

@@ -87,9 +87,13 @@ class MyDataset(Dataset):
                 return x
             import webdataset as wds
             import torchvision.transforms as transforms
-            img_transform = transforms.Compose(
-                [transforms.CenterCrop(256)]
-            )
+            # img_transform = transforms.Compose(
+            #     [transforms.CenterCrop(256)]
+            # )
+            img_transform = transforms.Compose([
+                transforms.CenterCrop(512),
+                transforms.Resize((args.my_img_size))
+            ])
             self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
             for pp in self.data_raw.pipeline:
                 if 'Resampled' in str(pp):
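Note: a quick sanity check of the new transform on a dummy tensor (a sketch; 224 stands in for --my_img_size, which stays a CLI flag). Resize receives an int (the extra parentheses around args.my_img_size do not make a tuple), so the 512-pixel center crop is scaled down to my_img_size:

import torch
import torchvision.transforms as transforms

img_transform = transforms.Compose([
    transforms.CenterCrop(512),
    transforms.Resize(224),      # assumed --my_img_size value
])
x = torch.rand(3, 640, 640)      # stands in for a decoded "torchrgb" image, CHW in [0, 1]
y = img_transform(x)
print(y.shape)                   # torch.Size([3, 224, 224])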

@@ -8,7 +8,7 @@ import torch
 from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
+import clip

 class L2pooling(nn.Module):
     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
@@ -134,6 +134,8 @@ class DISTS(torch.nn.Module):
         return score

+########################################################################################################
+
 import os, math, gc
 import torchvision as vision
 import torch
@@ -144,6 +146,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+# from pytorch_msssim import MS_SSIM

 class ToBinary(torch.autograd.Function):
@@ -159,6 +162,7 @@ class ToBinary(torch.autograd.Function):
 MyModule = torch.jit.ScriptModule
 MyFunction = torch.jit.script_method

+########################################################################################################

 class R_ENCODER(MyModule):
     def __init__(self, args):
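Note: the two hunks above only add section separators, but ToBinary is the step that makes this a binary autoencoder. A minimal straight-through sketch (an assumption; the commit's exact forward/backward are outside this diff): the forward pass rounds the encoder's sigmoid output to {0, 1}, and the backward pass copies the gradient through unchanged so the encoder still trains.

import torch

class ToBinarySketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + 0.5)    # x in (0, 1) from sigmoid -> hard 0/1 codes

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()     # straight-through estimator

x = torch.randn(2, 8, 4, 4, requires_grad=True)   # 8 = my_img_bit latent planes
b = ToBinarySketch.apply(torch.sigmoid(x))
b.sum().backward()                                 # gradient reaches the encoder side
print(b.unique(), x.grad.shape)                    # tensor([0., 1.]) torch.Size([2, 8, 4, 4])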
@@ -183,7 +187,7 @@ class R_ENCODER(MyModule):
         self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
         self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.COUT = nn.Conv2d(192, 8, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(192, args.my_img_bit, kernel_size=3, padding=1)

     @MyFunction
     def forward(self, x):
@@ -205,13 +209,14 @@ class R_ENCODER(MyModule):
         return torch.sigmoid(x)

+########################################################################################################

 class R_DECODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
-        self.CIN = nn.Conv2d(8, 192, kernel_size=3, padding=1)
+        self.CIN = nn.Conv2d(args.my_img_bit, 192, kernel_size=3, padding=1)

         self.B00 = nn.BatchNorm2d(192)
         self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
@@ -251,14 +256,31 @@ class R_DECODER(MyModule):
         return torch.sigmoid(x)

+########################################################################################################

 class RWKV_IMG(pl.LightningModule):
     def __init__(self, args):
         super().__init__()
         self.args = args

         self.encoder = R_ENCODER(args)
         self.decoder = R_DECODER(args)

+        clip_name = args.my_img_clip
+        if clip_name == 'B32':
+            clip_name = 'ViT-B/32'
+        elif clip_name == 'B16':
+            clip_name = 'ViT-B/16'
+        elif clip_name == 'L14':
+            clip_name = 'ViT-L/14'
+        self.clip_model, self.clip_preprocess = clip.load(clip_name, jit = True)
+        for n, p in self.named_parameters():
+            if 'clip_model' in n:
+                p.requires_grad = False
+
         self.loss_dists = DISTS()
+        # self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)

     def configure_optimizers(self):
         args = self.args
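Note: clip.load downloads the named checkpoint and returns (model, preprocess); the loop after it freezes every parameter whose name contains 'clip_model', so only the encoder/decoder receive gradient updates. The same name-based pattern on a self-contained toy module:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)     # stays trainable
        self.clip_model = nn.Linear(4, 4)  # stand-in for the frozen CLIP tower

m = Toy()
for n, p in m.named_parameters():
    if 'clip_model' in n:
        p.requires_grad = False

print([n for n, p in m.named_parameters() if p.requires_grad])
# ['encoder.weight', 'encoder.bias']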
@@ -308,17 +330,25 @@ class RWKV_IMG(pl.LightningModule):
         out = self(img)
         if self.trainer.is_global_zero:
             if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
+                img_dir = f"test/image_model/{args.run_name}"
+                if not os.path.exists(img_dir):
+                    os.makedirs(img_dir)
                 vision.utils.save_image(
-                    img[:4], f"test/image_model/{self.trainer.global_step}-src.jpg"
+                    img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
                 )
                 vision.utils.save_image(
-                    out[:4], f"test/image_model/{self.trainer.global_step}-out.jpg"
+                    out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
                 )

-        loss_l1 = F.l1_loss(out, img)
+        # loss_l1 = F.l1_loss(out, img)
+        # loss_ssim = 1 - self.loss_ssim(out, img)
+        # return loss_dists# * 1# + loss_l1 * 1 + # + loss_ssim * 0.4
         loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
-        return loss_l1 + loss_dists
+        loss_clip = F.mse_loss(self.clip_model.encode_image(img), self.clip_model.encode_image(out))
+        return loss_dists + loss_clip * args.my_img_clip_scale

     def training_step_end(self, batch_parts):
         all = self.all_gather(batch_parts)
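Note: the new objective swaps the L1 term for a CLIP feature-matching term: the reconstruction must agree with the source in CLIP's embedding space, weighted by --my_img_clip_scale, on top of the DISTS perceptual loss. A standalone sketch of the CLIP term (assumes OpenAI's clip package; CPU and default jit=False here for simplicity, where the commit uses jit=True). The inputs must already be at CLIP's input resolution, 224 for ViT-B/32, which is presumably what --my_img_size is matched to:

import torch
import torch.nn.functional as F
import clip  # pip install git+https://github.com/openai/CLIP.git

model, _ = clip.load("ViT-B/32", device="cpu")

img = torch.rand(2, 3, 224, 224)                       # dummy source batch
out = torch.rand(2, 3, 224, 224, requires_grad=True)   # dummy reconstruction

loss_clip = F.mse_loss(model.encode_image(img), model.encode_image(out))
loss_clip.backward()      # gradients flow through frozen CLIP into the reconstruction
print(out.grad.shape)     # torch.Size([2, 3, 224, 224])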

@@ -61,11 +61,9 @@ class train_callback(pl.Callback):
         if len(args.wandb) > 0:
             print("Login to wandb...")
             import wandb
-            model_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
             wandb.init(
                 project=args.wandb,
-                name=model_name + " " + args.my_timestamp,
+                name=args.run_name + " " + args.my_timestamp,
                 config=args,
                 save_code=False,
             )

@@ -3,24 +3,10 @@
 ########################################################################################################

 if __name__ == "__main__":
-    print("########## work in progress ##########")
-    import os, warnings, math, datetime, sys, time
-    import numpy as np
     from argparse import ArgumentParser
-    import torch
-    from torch.utils.data import DataLoader
-    import deepspeed
-    import pytorch_lightning as pl
     from pytorch_lightning import Trainer
-    from pytorch_lightning import seed_everything
-    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-
-    # print("WARNING: THIS IS ONLY FOR DEBUG")
-    # seed_everything(42)
-
-    np.set_printoptions(precision=4, suppress=True, linewidth=200)
-    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
-    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
+
+    print("########## work in progress ##########")

 ########################################################################################################
 #
# #
@@ -61,11 +47,11 @@ if __name__ == "__main__":
     # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

     parser = ArgumentParser()
-    parser = Trainer.add_argparse_args(parser)
     parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
     parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
     parser.add_argument("--proj_dir", default="out", type=str)
+    parser.add_argument("--random_seed", default="-1", type=int)
     parser.add_argument("--data_file", default="", type=str)
     parser.add_argument("--data_type", default="utf-8", type=str)
@@ -98,7 +84,35 @@ if __name__ == "__main__":
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
+    parser.add_argument("--my_img_version", default=0, type=int)
+    parser.add_argument("--my_img_size", default=0, type=int)
+    parser.add_argument("--my_img_bit", default=0, type=int)
+    parser.add_argument("--my_img_clip", default='x', type=str)
+    parser.add_argument("--my_img_clip_scale", default=1, type=float)
+
+    parser = Trainer.add_argparse_args(parser)
     args = parser.parse_args()

+    ########################################################################################################
+
+    import os, warnings, math, datetime, sys, time
+    import numpy as np
+    import torch
+    from torch.utils.data import DataLoader
+    import deepspeed
+    import pytorch_lightning as pl
+    from pytorch_lightning import seed_everything
+    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+
+    if args.random_seed >= 0:
+        print(f"########## WARNING: GLOBAL SEED SET TO {args.random_seed} ##########\n" * 3)
+        seed_everything(args.random_seed)
+
+    np.set_printoptions(precision=4, suppress=True, linewidth=200)
+    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
+    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
+
+    # os.environ["WDS_SHOW_SEED"] = "1"
+
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
     args.enable_checkpointing = False
     args.replace_sampler_ddp = False
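Note: the heavy imports now run after parse_args, so --random_seed exists before seeding. seed_everything seeds Python's random, NumPy, and torch in one call, which is why a fixed seed makes runs repeat exactly (and, as the tripled warning stresses, makes every rank sample the same way). A minimal check:

import torch
from pytorch_lightning import seed_everything

seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
assert torch.equal(a, b)  # same seed -> identical draws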
@@ -112,6 +126,11 @@ if __name__ == "__main__":
     args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
     os.environ["RWKV_T_MAX"] = str(args.ctx_len)

+    if args.data_type == "wds_img":
+        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
+        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
+    else:
+        args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
     if not os.path.exists(args.proj_dir):
         os.makedirs(args.proj_dir)
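Note: a worked example of the new run naming, with hypothetical flag values (--my_img_version 0 --my_img_size 224 --my_img_bit 8 --my_img_clip B32 --my_img_clip_scale 1):

my_img_version, my_img_size, my_img_bit = 0, 224, 8
my_img_clip, my_img_clip_scale = "B32", 1.0   # type=float, so 1 prints as 1.0
run_name = f"v{my_img_version}-{my_img_size}-{my_img_bit}bit-{my_img_clip}x{my_img_clip_scale}"
print(run_name)  # v0-224-8bit-B32x1.0

proj_dir then becomes e.g. out-v0-224-8bit-B32x1.0, and the sample images from training_step land in test/image_model/v0-224-8bit-B32x1.0/, so runs with different settings no longer overwrite each other.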
