@@ -3,11 +3,24 @@
 ########################################################################################################
 
 import numpy as np
-import os
+import os, math, gc
 import torch
-from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision as vision
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
+import deepspeed
+from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+# from pytorch_msssim import MS_SSIM
+
+def __nop(ob):
+    return ob
+MyModule = torch.jit.ScriptModule
+# MyFunction = __nop
+MyFunction = torch.jit.script_method
+
+import clip
+from transformers import CLIPModel
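Note: the block added at the top of the file doubles as a TorchScript switch. MyModule / MyFunction point at torch.jit.ScriptModule / torch.jit.script_method here, while the __nop helper is the eager fallback. A minimal sketch of the same pattern (the USE_JIT flag is illustrative, not part of the diff):

import torch
import torch.nn as nn

USE_JIT = True  # illustrative toggle; the diff hard-codes the TorchScript branch

def __nop(ob):
    return ob  # identity decorator for the eager fallback

if USE_JIT:
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
else:
    MyModule = nn.Module
    MyFunction = __nop

class Tiny(MyModule):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    @MyFunction
    def forward(self, x):
        return torch.relu(self.lin(x))

print(Tiny()(torch.randn(2, 4)).shape)  # torch.Size([2, 4])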
@@ -39,7 +52,7 @@ class L2pooling(nn.Module):
 class DISTS(torch.nn.Module):
     def __init__(self, load_weights=True):
         super(DISTS, self).__init__()
-        vgg_pretrained_features = models.vgg16(
+        vgg_pretrained_features = vision.models.vgg16(
             weights="VGG16_Weights.IMAGENET1K_V1"
         ).features
         self.stage1 = torch.nn.Sequential()
@@ -134,39 +147,19 @@ class DISTS(torch.nn.Module):
         else:
             return score
 
 class ToBinary(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):#, noise_scale):
         # if noise_scale > 0:
         #     noise_min = 0.5 - noise_scale / 2
         #     noise_max = 0.5 + noise_scale / 2
         #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
         # else:
         return torch.floor(x + 0.5) # no need for noise when we have plenty of data
 
-########################################################################################################
-
-import os, math, gc
-import torchvision as vision
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import pytorch_lightning as pl
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-from pytorch_lightning.strategies import DeepSpeedStrategy
-import deepspeed
-from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
-# from pytorch_msssim import MS_SSIM
-
-class ToBinary(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x):#, noise_scale):
-        # if noise_scale > 0:
-        #     noise_min = 0.5 - noise_scale / 2
-        #     noise_max = 0.5 + noise_scale / 2
-        #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
-        # else:
-        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
 
     @staticmethod
     def backward(ctx, grad_output):
         return grad_output.clone()#, None
 
-MyModule = torch.jit.ScriptModule
-MyFunction = torch.jit.script_method
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output.clone()#, None
 
 ########################################################################################################
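ToBinary, kept once after the duplicated block above is removed, is a straight-through estimator: the forward pass rounds the encoder's sigmoid output to hard 0/1 bits, while backward returns the incoming gradient unchanged, so the rounding is invisible to the optimizer. A standalone sketch of that behaviour:

import torch

class ToBinary(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + 0.5)  # hard-quantize to {0, 1}

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()   # straight-through: identity gradient

x = torch.rand(4, requires_grad=True)
y = ToBinary.apply(x)
y.sum().backward()
print(y)       # entries are exactly 0.0 or 1.0
print(x.grad)  # all ones: the gradient passes through the rounding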
@@ -174,30 +167,45 @@ class R_ENCODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
 
-        self.B00 = nn.BatchNorm2d(12)
-        self.C00 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C01 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-        self.C02 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C03 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-
-        self.B10 = nn.BatchNorm2d(48)
-        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-
-        self.B20 = nn.BatchNorm2d(192)
-        self.C20 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C21 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-
-        self.COUT = nn.Conv2d(192, args.my_img_bit, kernel_size=3, padding=1)
+        dd = 8
+        self.Bxx = nn.BatchNorm2d(dd*64)
+
+        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+
+        self.B00 = nn.BatchNorm2d(dd*4)
+        self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+
+        self.B20 = nn.BatchNorm2d(dd*64)
+        self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.B21 = nn.BatchNorm2d(dd*64)
+        # self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+
+        self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
 
     @MyFunction
-    def forward(self, x):
-        ACT = F.silu
+    def forward(self, img):
+        ACT = F.mish
+
+        x = self.CIN(img)
+        xx = self.Bxx(F.pixel_unshuffle(x, 8))
+        x = x + self.Cx1(ACT(self.Cx0(x)))
 
         x = F.pixel_unshuffle(x, 2)
         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
@@ -210,9 +218,10 @@ class R_ENCODER(MyModule):
         x = F.pixel_unshuffle(x, 2)
         x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
         x = x + self.C23(ACT(self.C22(x)))
+        # x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
+        # x = x + self.C27(ACT(self.C26(x)))
 
-        x = self.COUT(x)
+        x = self.COUT(x + xx)
         return torch.sigmoid(x)
 
 ########################################################################################################
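In the reworked encoder, CIN lifts the RGB input to dd = 8 channels, each F.pixel_unshuffle(x, 2) halves the spatial resolution while multiplying channels by 4 (dd -> dd*4 -> dd*16 -> dd*64), and xx = self.Bxx(F.pixel_unshuffle(x, 8)) is a long skip at the final resolution that is added back just before COUT. A shape walk-through, assuming a 224x224 input for illustration:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 224, 224)   # after CIN: dd = 8 channels
xx = F.pixel_unshuffle(x, 8)      # long skip: (1, 512, 28, 28) = dd*64 channels
for step in range(3):
    x = F.pixel_unshuffle(x, 2)   # each stage: H, W halved; channels x4
    print(step, tuple(x.shape))
# 0 (1, 32, 112, 112)  -> dd*4,  matches B00 / C00..C03
# 1 (1, 128, 56, 56)   -> dd*16, matches B10 / C10..C13
# 2 (1, 512, 28, 28)   -> dd*64, matches B20 / C20..C23
print(tuple(xx.shape))            # (1, 512, 28, 28): same shape, so COUT(x + xx) is valid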
@@ -221,35 +230,45 @@ class R_DECODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
 
-        self.CIN = nn.Conv2d(args.my_img_bit, 192, kernel_size=3, padding=1)
-
-        self.B00 = nn.BatchNorm2d(192)
-        self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C01 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C02 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C03 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-
-        self.B10 = nn.BatchNorm2d(48)
-        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-
-        self.B20 = nn.BatchNorm2d(12)
-        self.C20 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C21 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-        self.C22 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C23 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
+        dd = 8
+        self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
+
+        self.B00 = nn.BatchNorm2d(dd*64)
+        self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.B01 = nn.BatchNorm2d(dd*64)
+        # self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+
+        self.B20 = nn.BatchNorm2d(dd*4)
+        self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
 
     @MyFunction
-    def forward(self, x):
-        ACT = F.silu
-
-        x = self.CIN(x)
+    def forward(self, code):
+        ACT = F.mish
+        x = self.CIN(code)
 
         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
         x = x + self.C03(ACT(self.C02(x)))
+        # x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
+        # x = x + self.C07(ACT(self.C06(x)))
         x = F.pixel_shuffle(x, 2)
 
         x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
@@ -260,9 +279,12 @@ class R_DECODER(MyModule):
         x = x + self.C23(ACT(self.C22(x)))
         x = F.pixel_shuffle(x, 2)
 
+        x = x + self.Cx1(ACT(self.Cx0(x)))
         x = self.COUT(x)
         return torch.sigmoid(x)
 
 ########################################################################################################
 
 def cosine_loss(x, y):
     x = F.normalize(x, dim=-1)
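The decoder runs the same ladder in reverse: CIN maps the my_img_bit-channel code back to dd*64 channels, each F.pixel_shuffle(x, 2) doubles the resolution and divides channels by 4 (dd*64 -> dd*16 -> dd*4 -> dd), and Cx0/Cx1 plus COUT produce the 3-channel image through a final sigmoid. Mirroring the encoder sketch above:

import torch
import torch.nn.functional as F

z = torch.randn(1, 512, 28, 28)  # after decoder CIN: dd*64 channels
for step in range(3):
    z = F.pixel_shuffle(z, 2)    # each stage: H, W doubled; channels /4
    print(step, tuple(z.shape))
# 0 (1, 128, 56, 56)   -> dd*16, matches B10 / C10..C13
# 1 (1, 32, 112, 112)  -> dd*4,  matches B20 / C20..C23
# 2 (1, 8, 224, 224)   -> dd,    matches Cx0/Cx1, then COUT -> 3 channels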
@@ -292,10 +314,10 @@ class RWKV_IMG(pl.LightningModule):
         if self.clip_model == None:
             self.clip_model, _ = clip.load(clip_name, jit = True)
         self.register_buffer(
-            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
+            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
         )
         self.register_buffer(
-            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)
+            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
         )
 
         for n, p in self.named_parameters():
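clip_mean / clip_std are the standard CLIP preprocessing statistics; registering them as buffers keeps them on the module's device, and reshaping to (1, 3, 1, 1) rather than (1, -1, 1, 1) just spells out the NCHW broadcast. Roughly what the normalization amounts to (names here are illustrative):

import torch

clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

img = torch.rand(2, 3, 224, 224)             # decoder output in [0, 1]
img_for_clip = (img - clip_mean) / clip_std  # broadcasts over batch, H, W
print(img_for_clip.shape)                    # torch.Size([2, 3, 224, 224])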
@@ -393,8 +415,23 @@ class RWKV_IMG(pl.LightningModule):
         )
         m = {}
         for n in self.state_dict():
+            scale = 1
             p = self.state_dict()[n]
             shape = p.shape
+            ss = n.split('.')
+
+            # if ss[0] in ['encoder', 'decoder']:
+            #     if ss[2] == 'bias':
+            #         scale = 0
+            #     # elif n == 'encoder.CIN.weight':
+            #     #     nn.init.dirac_(p)
+            #     else:
+            #         try:
+            #             if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
+            #                 scale = 0
+            #         except:
+            #             pass
+            # m[n] = p * scale
+
             m[n] = p
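With the scaled-init branch left commented out, generate_init_weight copies every parameter through unchanged (scale stays 1 and m[n] = p); scale and ss = n.split('.') are only kept so the disabled per-layer zeroing can be switched back on. A stripped-down sketch of the loop as it now behaves, on a hypothetical toy model:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))

m = {}
for n in model.state_dict():
    scale = 1                  # the disabled branch would set this to 0 for selected layers
    p = model.state_dict()[n]
    ss = n.split('.')          # e.g. ['0', 'weight']; unused while the branch is commented out
    m[n] = p                   # the disabled branch would instead store p * scale

print(sorted(m)[:3])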