@@ -8,7 +8,7 @@ import torch
 from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
+import clip

 class L2pooling(nn.Module):
     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
@@ -134,6 +134,8 @@ class DISTS(torch.nn.Module):
             return score

 ########################################################################################################

 import os, math, gc
 import torchvision as vision
 import torch
@@ -144,6 +146,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+# from pytorch_msssim import MS_SSIM

 class ToBinary(torch.autograd.Function):
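The hunk above ends at the `ToBinary` declaration; its body is not part of this diff. As a rough sketch only, a straight-through binarizer of this kind typically quantizes the sigmoid-activated latent in the forward pass and passes the gradient through unchanged in the backward pass (names and the exact rounding rule below are assumptions, not the repository's code):

class ToBinarySketch(torch.autograd.Function):
    # Hypothetical illustration; the real ToBinary implementation is not shown in this diff.
    @staticmethod
    def forward(ctx, x):
        # x is expected in [0, 1] (the encoder ends with torch.sigmoid); quantize to {0, 1}.
        return torch.floor(x + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: floor() has zero gradient almost everywhere,
        # so treat the quantizer as the identity when backpropagating.
        return grad_output.clone()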
@@ -159,6 +162,7 @@ class ToBinary(torch.autograd.Function):
 MyModule = torch.jit.ScriptModule
 MyFunction = torch.jit.script_method

 ########################################################################################################

 class R_ENCODER(MyModule):
     def __init__(self, args):
@@ -183,7 +187,7 @@ class R_ENCODER(MyModule):
         self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
         self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)

-        self.COUT = nn.Conv2d(192, 8, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(192, args.my_img_bit, kernel_size=3, padding=1)

     @MyFunction
     def forward(self, x):
@@ -205,13 +209,14 @@ class R_ENCODER(MyModule):
         return torch.sigmoid(x)

 ########################################################################################################

 class R_DECODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args

-        self.CIN = nn.Conv2d(8, 192, kernel_size=3, padding=1)
+        self.CIN = nn.Conv2d(args.my_img_bit, 192, kernel_size=3, padding=1)

         self.B00 = nn.BatchNorm2d(192)
         self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
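Together with the matching encoder hunk above, this keeps the encoder's output channels and the decoder's input channels tied to `args.my_img_bit` instead of a hard-coded 8. A minimal shape check, assuming (this is not verified by the diff) that `R_ENCODER`/`R_DECODER` only read `my_img_bit` in `__init__` and that their down/upsampling stages mirror each other:

import types
import torch

args = types.SimpleNamespace(my_img_bit=8)   # hypothetical args object
enc, dec = R_ENCODER(args), R_DECODER(args)
x = torch.rand(1, 3, 224, 224)
z = enc(x)                                   # latent with args.my_img_bit channels, values in [0, 1]
assert z.shape[1] == args.my_img_bit
rec = dec(z)                                 # reconstruction; 3 channels if the stages mirror each other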
@@ -251,14 +256,31 @@ class R_DECODER(MyModule):
         return torch.sigmoid(x)

 ########################################################################################################

 class RWKV_IMG(pl.LightningModule):
     def __init__(self, args):
         super().__init__()
         self.args = args

         self.encoder = R_ENCODER(args)
         self.decoder = R_DECODER(args)

+        clip_name = args.my_img_clip
+        if clip_name == 'B32':
+            clip_name = 'ViT-B/32'
+        elif clip_name == 'B16':
+            clip_name = 'ViT-B/16'
+        elif clip_name == 'L14':
+            clip_name = 'ViT-L/14'
+        self.clip_model, self.clip_preprocess = clip.load(clip_name, jit=True)
+
+        for n, p in self.named_parameters():
+            if 'clip_model' in n:
+                p.requires_grad = False

         self.loss_dists = DISTS()
         # self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)

     def configure_optimizers(self):
         args = self.args
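The code added above (and the `training_step` changes below) read several fields from `args`. A hypothetical namespace covering just those fields — the names come from the diff, the values are placeholders — might look like:

import types

args = types.SimpleNamespace(
    my_img_bit=8,            # channels of the binary latent (encoder COUT / decoder CIN)
    my_img_clip='B16',       # mapped to 'ViT-B/16' before clip.load()
    my_img_clip_scale=1.0,   # weight on the CLIP-embedding MSE term in training_step
    run_name='img-demo',     # used for the test/image_model/{run_name} preview directory
)
# model = RWKV_IMG(args)     # assumes __init__ touches no other args fields, which this diff does not confirm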
@@ -308,17 +330,25 @@ class RWKV_IMG(pl.LightningModule):
         out = self(img)
         if self.trainer.is_global_zero:
             if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
+                img_dir = f"test/image_model/{args.run_name}"
+                if not os.path.exists(img_dir):
+                    os.makedirs(img_dir)
                 vision.utils.save_image(
-                    img[:4], f"test/image_model/{self.trainer.global_step}-src.jpg"
+                    img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
                 )
                 vision.utils.save_image(
-                    out[:4], f"test/image_model/{self.trainer.global_step}-out.jpg"
+                    out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
                 )

-        loss_l1 = F.l1_loss(out, img)
+        # loss_l1 = F.l1_loss(out, img)
         # loss_ssim = 1 - self.loss_ssim(out, img)
+        # return loss_dists# * 1# + loss_l1 * 1 + # + loss_ssim * 0.4

         loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)

-        return loss_l1 + loss_dists
+        loss_clip = F.mse_loss(self.clip_model.encode_image(img), self.clip_model.encode_image(out))
+
+        return loss_dists + loss_clip * args.my_img_clip_scale

     def training_step_end(self, batch_parts):
         all = self.all_gather(batch_parts)
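Restating the effect of the last hunk: the plain L1 term is dropped (left as a comment), and an MSE term between the frozen CLIP model's embeddings of source and reconstruction, scaled by `args.my_img_clip_scale`, is added to the DISTS loss. A standalone sketch of the resulting objective (the function name and arguments are placeholders, not part of the diff):

def reconstruction_loss(out, img, clip_model, loss_dists, clip_scale):
    # Perceptual DISTS term between reconstruction and source, as in the diff.
    d = loss_dists(out, img, require_grad=True, batch_average=True)
    # New term: MSE between CLIP image embeddings of source and reconstruction.
    c = F.mse_loss(clip_model.encode_image(img), clip_model.encode_image(out))
    return d + c * clip_scale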