branch: main
commit 9e325b88cb (parent 74fedc0d86)
BlinkDL, 3 years ago

@@ -9,6 +9,7 @@ from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
 import clip
+from transformers import CLIPModel
 
 class L2pooling(nn.Module):
     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
@@ -151,12 +152,17 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 
 class ToBinary(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x):
-        return torch.floor(x + torch.empty_like(x).uniform_(0.4, 0.6))
+    def forward(ctx, x):#, noise_scale):
+        # if noise_scale > 0:
+        #     noise_min = 0.5 - noise_scale / 2
+        #     noise_max = 0.5 + noise_scale / 2
+        #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
+        # else:
+        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output.clone()
+        return grad_output.clone()#, None
 
 MyModule = torch.jit.ScriptModule
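
Note on the hunk above: the forward pass now binarizes deterministically with floor(x + 0.5) instead of adding uniform noise in [0.4, 0.6), while backward stays a straight-through estimator that copies the incoming gradient. A minimal sketch of the behavior (the tensor values are illustrative, not from the repo):

import torch

x = torch.tensor([0.2, 0.49, 0.51, 0.9], requires_grad=True)
z = ToBinary.apply(x)      # deterministic rounding -> tensor([0., 0., 1., 1.])
z.sum().backward()
print(x.grad)              # all ones: the gradient passes straight through the rounding
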
@@ -258,6 +264,11 @@ class R_DECODER(MyModule):
 
 ########################################################################################################
 
+def cosine_loss(x, y):
+    x = F.normalize(x, dim=-1)
+    y = F.normalize(y, dim=-1)
+    return 1 - torch.einsum('ij,ij->i',[x,y])
+
 class RWKV_IMG(pl.LightningModule):
     def __init__(self, args):
         super().__init__()
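
For clarity, the new cosine_loss returns one value per row: it L2-normalizes each row of x and y and computes 1 minus their dot product, i.e. the per-sample cosine distance. A small illustrative check (made-up values):

import torch

a = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
b = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
# row 0: identical direction -> 0.0; row 1: 45 degrees apart -> 1 - cos(45 deg) ~ 0.2929
print(cosine_loss(a, b))   # tensor([0.0000, 0.2929])
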
@@ -266,6 +277,7 @@ class RWKV_IMG(pl.LightningModule):
         self.encoder = R_ENCODER(args)
         self.decoder = R_DECODER(args)
 
+        self.clip_model = None
         clip_name = args.my_img_clip
         if clip_name == 'B32':
             clip_name = 'ViT-B/32'
@@ -273,7 +285,18 @@ class RWKV_IMG(pl.LightningModule):
             clip_name = 'ViT-B/16'
         elif clip_name == 'L14':
             clip_name = 'ViT-L/14'
-        self.clip_model, self.clip_preprocess = clip.load(clip_name, jit = True)
+        elif clip_name == 'OB32':
+            clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+            self.clip_model = CLIPModel.from_pretrained(clip_name)
+            self.clip_model.encode_image = self.clip_model.get_image_features
+        if self.clip_model == None:
+            self.clip_model, _ = clip.load(clip_name, jit = True)
+        self.register_buffer(
+            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
+        )
+        self.register_buffer(
+            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)
+        )
 
         for n, p in self.named_parameters():
             if 'clip_model' in n:
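
Context for the block above: the new 'OB32' option loads the LAION OpenCLIP ViT-B/32 checkpoint through Hugging Face transformers and aliases get_image_features to encode_image, so the OpenAI clip backend and the transformers backend are called identically later on. clip_mean and clip_std are the standard CLIP preprocessing statistics, registered as buffers so the model can normalize raw [0, 1] image tensors itself. A rough usage sketch (model and img are placeholder names, and the spatial size is assumed to match what the chosen CLIP expects):

with torch.no_grad():
    # normalize with the registered CLIP statistics, then embed with whichever backend was loaded
    feats = model.clip_model.encode_image((img - model.clip_mean) / model.clip_std)
print(feats.shape)  # e.g. [batch, 512] for the ViT-B/32 variants
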
@@ -320,7 +343,7 @@ class RWKV_IMG(pl.LightningModule):
 
     def forward(self, img):
         z = self.encoder(img)
-        z = ToBinary.apply(z)
+        z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
         out = self.decoder(z)
         return out
 
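
The forward pass itself is unchanged apart from the commented-out noise-scale argument: encode, binarize the latent, decode. A toy round trip (the shapes and the 224 resolution are assumptions, not taken from the commit):

img = torch.rand(2, 3, 224, 224)   # batch of images in [0, 1]
recon = model(img)                 # R_ENCODER -> ToBinary -> R_DECODER
print(recon.shape)                 # reconstruction with the same shape as img
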
@@ -340,15 +363,18 @@ class RWKV_IMG(pl.LightningModule):
                     out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
                 )
 
         # loss_l1 = F.l1_loss(out, img)
         # loss_ssim = 1 - self.loss_ssim(out, img)
-        # return loss_dists# * 1# + loss_l1 * 1 + # + loss_ssim * 0.4
         loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
-        loss_clip = F.mse_loss(self.clip_model.encode_image(img), self.clip_model.encode_image(out))
-        return loss_dists + loss_clip * args.my_img_clip_scale
+        iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
+        ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
+        loss_clip = torch.mean(cosine_loss(iii, ooo))
+        if args.my_img_l1_scale > 0:
+            loss_l1 = F.l1_loss(out, img)
+            return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
+        else:
+            return loss_dists + loss_clip * args.my_img_clip_scale
 
     def training_step_end(self, batch_parts):
         all = self.all_gather(batch_parts)
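
The perceptual term changes here from an MSE between raw CLIP embeddings to the mean per-sample cosine distance between embeddings of the CLIP-normalized source and reconstruction, with an optional L1 term gated by --my_img_l1_scale. One property worth noting: cosine distance ignores embedding magnitude, which the previous MSE did not. A quick illustration with random stand-ins for CLIP features:

import torch
import torch.nn.functional as F

iii = torch.randn(4, 512)
ooo = 3.0 * iii                            # same directions, different magnitude
print(F.mse_loss(iii, ooo))                # large: penalizes the scale gap
print(torch.mean(cosine_loss(iii, ooo)))   # ~0: directions are identical
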

@ -111,8 +111,16 @@ class train_callback(pl.Callback):
args = self.args
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
if args.data_type == 'wds_img':
raw_dict = pl_module.state_dict()
to_save_dict = {}
for k in raw_dict:
if k.startswith('encoder.') or k.startswith('decoder.'):
to_save_dict[k] = raw_dict[k]
else:
to_save_dict = pl_module.state_dict()
torch.save(
pl_module.state_dict(),
to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
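
With the change above, checkpoints for wds_img runs contain only the encoder.* and decoder.* tensors, so the frozen CLIP weights are not re-saved every epoch. A sketch of loading such a file back into RWKV_IMG (the path is a placeholder; strict=False is needed because the clip_model weights and the clip_mean/clip_std buffers are absent from the file and are recreated in __init__ instead):

state = torch.load("rwkv-0.pth", map_location="cpu")
missing, unexpected = model.load_state_dict(state, strict=False)
print(len(missing), len(unexpected))  # missing lists clip_model.* and the clip buffers
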

@@ -84,11 +84,14 @@ if __name__ == "__main__":
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
 
-    parser.add_argument("--my_img_version", default=0, type=int)
+    parser.add_argument("--my_img_version", default=0, type=str)
     parser.add_argument("--my_img_size", default=0, type=int)
     parser.add_argument("--my_img_bit", default=0, type=int)
     parser.add_argument("--my_img_clip", default='x', type=str)
     parser.add_argument("--my_img_clip_scale", default=1, type=float)
+    parser.add_argument("--my_img_l1_scale", default=0, type=float)
+    parser.add_argument("--my_img_encoder", default='x', type=str)
+    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
 
     parser = Trainer.add_argparse_args(parser)
     args = parser.parse_args()
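
The new flags are --my_img_l1_scale (weight of the optional L1 term, 0 disables it) and --my_img_encoder, while --my_img_version is now parsed as a string; --my_img_noise_scale stays commented out to match the disabled noise path in ToBinary. A hypothetical invocation, purely to show where the flags go (the script name and every value here are placeholders, not taken from the commit):

python train.py --data_type wds_img --my_img_size 224 --my_img_bit 13 \
    --my_img_clip OB32 --my_img_clip_scale 1 --my_img_l1_scale 0
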
@@ -105,7 +108,7 @@ if __name__ == "__main__":
     from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 
     if args.random_seed >= 0:
-        print(f"########## WARNING: GLOBAL SEED SET TO f{args.random_seed} ##########\n" * 3)
+        print(f"########## WARNING: GLOBAL SEED SET TO {args.random_seed} ##########\n" * 3)
         seed_everything(args.random_seed)
 
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
