@@ -9,6 +9,7 @@ from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
 import clip
+from transformers import CLIPModel
 
 class L2pooling(nn.Module):
     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
@@ -151,12 +152,17 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 class ToBinary(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x):
-        return torch.floor(x + torch.empty_like(x).uniform_(0.4, 0.6))
+    def forward(ctx, x):#, noise_scale):
+        # if noise_scale > 0:
+        #     noise_min = 0.5 - noise_scale / 2
+        #     noise_max = 0.5 + noise_scale / 2
+        #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
+        # else:
+        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output.clone()
+        return grad_output.clone()#, None
 
 MyModule = torch.jit.ScriptModule
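
Reviewer note: ToBinary is a straight-through estimator. The new forward drops the random dither (uniform noise in [0.4, 0.6]) in favor of a deterministic floor(x + 0.5), while backward still passes the incoming gradient through unchanged, so the encoder keeps receiving a learning signal despite the non-differentiable rounding. A minimal, self-contained sketch of that behaviour, using the class as it now reads:

```python
import torch

class ToBinary(torch.autograd.Function):
    # hard 0/1 rounding in forward, identity (straight-through) gradient in backward
    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

x = torch.rand(2, 4, requires_grad=True)
z = ToBinary.apply(x)   # values are exactly 0.0 or 1.0
z.sum().backward()
print(z)                # binarized code
print(x.grad)           # all ones: gradient passed straight through
```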
@@ -258,6 +264,11 @@ class R_DECODER(MyModule):
 
 ########################################################################################################
 
+def cosine_loss(x, y):
+    x = F.normalize(x, dim=-1)
+    y = F.normalize(y, dim=-1)
+    return 1 - torch.einsum('ij,ij->i',[x,y])
+
 class RWKV_IMG(pl.LightningModule):
     def __init__(self, args):
         super().__init__()
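
Reviewer note: the new cosine_loss helper returns, per row, one minus the cosine similarity of its two inputs, so identical directions give 0 and orthogonal directions give 1. A quick sketch checking it against torch.nn.functional.cosine_similarity (the 8x512 tensors are arbitrary stand-ins):

```python
import torch
import torch.nn.functional as F

def cosine_loss(x, y):
    # as added in the hunk above: 1 - cosine similarity per row
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return 1 - torch.einsum('ij,ij->i', [x, y])

a = torch.randn(8, 512)
b = torch.randn(8, 512)
reference = 1 - F.cosine_similarity(a, b, dim=-1)
assert torch.allclose(cosine_loss(a, b), reference, atol=1e-5)
print(cosine_loss(a, a))   # ~0 for identical inputs
```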
@@ -266,6 +277,7 @@ class RWKV_IMG(pl.LightningModule):
         self.encoder = R_ENCODER(args)
         self.decoder = R_DECODER(args)
 
+        self.clip_model = None
         clip_name = args.my_img_clip
         if clip_name == 'B32':
             clip_name = 'ViT-B/32'
@@ -273,7 +285,18 @@ class RWKV_IMG(pl.LightningModule):
             clip_name = 'ViT-B/16'
         elif clip_name == 'L14':
             clip_name = 'ViT-L/14'
-        self.clip_model, self.clip_preprocess = clip.load(clip_name, jit = True)
+        elif clip_name == 'OB32':
+            clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+            self.clip_model = CLIPModel.from_pretrained(clip_name)
+            self.clip_model.encode_image = self.clip_model.get_image_features
+        if self.clip_model == None:
+            self.clip_model, _ = clip.load(clip_name, jit = True)
+        self.register_buffer(
+            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
+        )
+        self.register_buffer(
+            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)
+        )
 
         for n, p in self.named_parameters():
             if 'clip_model' in n:
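
Reviewer note: the OpenAI checkpoints ('B32', 'B16', 'L14') still go through clip.load, while the new 'OB32' option loads the LAION ViT-B/32 weights through transformers.CLIPModel and aliases get_image_features to encode_image, so both backends are called the same way later on. The registered clip_mean / clip_std buffers are the standard CLIP preprocessing statistics; training_step applies them by hand instead of going through clip's preprocess pipeline. A small sketch of that normalization step (the image batch is a placeholder, assumed to be NCHW in [0, 1]):

```python
import torch

# CLIP preprocessing statistics, as registered as buffers in __init__ above
clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)

# hypothetical batch of reconstructed images, NCHW, already in [0, 1]
image_batch = torch.rand(4, 3, 224, 224)

# same normalization training_step performs before encode_image
normalized = (image_batch - clip_mean) / clip_std
print(normalized.shape)  # torch.Size([4, 3, 224, 224])
```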
@@ -320,7 +343,7 @@ class RWKV_IMG(pl.LightningModule):
 
     def forward(self, img):
         z = self.encoder(img)
-        z = ToBinary.apply(z)
+        z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
         out = self.decoder(z)
         return out
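
Reviewer note: forward keeps the encode, binarize, decode pipeline; the only change here is the commented-out my_img_noise_scale argument kept for reference. The toy sketch below shows the same binary-bottleneck wiring with stand-in encoder and decoder modules (R_ENCODER and R_DECODER need the full args object, so plain conv layers are used purely for illustration):

```python
import torch
import torch.nn as nn

class ToBinary(torch.autograd.Function):
    # same straight-through binarizer as in the diff
    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

class TinyBinaryAE(nn.Module):
    # hypothetical stand-in for the RWKV_IMG encoder/decoder pair
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 8, 3, 2, 1), nn.Sigmoid())
        self.decoder = nn.Sequential(nn.ConvTranspose2d(8, 3, 4, 2, 1), nn.Sigmoid())

    def forward(self, img):
        z = self.encoder(img)   # soft codes in (0, 1)
        z = ToBinary.apply(z)   # hard 0/1 codes, identity gradient
        return self.decoder(z)

model = TinyBinaryAE()
out = model(torch.rand(2, 3, 64, 64))
print(out.shape)  # torch.Size([2, 3, 64, 64])
```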
@@ -340,15 +363,18 @@ class RWKV_IMG(pl.LightningModule):
                     out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
                 )
 
-        # loss_l1 = F.l1_loss(out, img)
-        # loss_ssim = 1 - self.loss_ssim(out, img)
-        # return loss_dists# * 1# + loss_l1 * 1 + # + loss_ssim * 0.4
 
         loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
 
-        loss_clip = F.mse_loss(self.clip_model.encode_image(img), self.clip_model.encode_image(out))
+        iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
+        ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
+        loss_clip = torch.mean(cosine_loss(iii, ooo))
 
-        return loss_dists + loss_clip * args.my_img_clip_scale
+        if args.my_img_l1_scale > 0:
+            loss_l1 = F.l1_loss(out, img)
+            return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
+        else:
+            return loss_dists + loss_clip * args.my_img_clip_scale
 
     def training_step_end(self, batch_parts):
         all = self.all_gather(batch_parts)
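
Reviewer note: the training objective is now DISTS plus my_img_clip_scale times the CLIP-feature cosine loss, with an optional L1 term added when my_img_l1_scale > 0. A sketch of just the combination logic, using stand-in tensors and hypothetical scale values (in the real code loss_dists comes from the DISTS module and iii / ooo from clip_model.encode_image):

```python
import torch
import torch.nn.functional as F

def cosine_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return 1 - torch.einsum('ij,ij->i', [x, y])

# stand-ins for the real perceptual loss and CLIP image features
loss_dists = torch.tensor(0.25)
iii = torch.randn(4, 512)          # features of the source images
ooo = torch.randn(4, 512)          # features of the reconstructions
loss_clip = torch.mean(cosine_loss(iii, ooo))

my_img_clip_scale = 1.0            # hypothetical hyper-parameter values
my_img_l1_scale = 0.0

total = loss_dists + loss_clip * my_img_clip_scale
if my_img_l1_scale > 0:
    out, img = torch.rand(4, 3, 64, 64), torch.rand(4, 3, 64, 64)
    total = total + F.l1_loss(out, img) * my_img_l1_scale
print(total)
```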