diff --git a/RWKV-v4neo/src/model_img.py b/RWKV-v4neo/src/model_img.py
index b592103..c46378c 100644
--- a/RWKV-v4neo/src/model_img.py
+++ b/RWKV-v4neo/src/model_img.py
@@ -9,6 +9,7 @@ from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
 import clip
+from transformers import CLIPModel
 
 class L2pooling(nn.Module):
     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
@@ -151,12 +152,17 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 
 class ToBinary(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x):
-        return torch.floor(x + torch.empty_like(x).uniform_(0.4, 0.6))
+    def forward(ctx, x):#, noise_scale):
+        # if noise_scale > 0:
+        #     noise_min = 0.5 - noise_scale / 2
+        #     noise_max = 0.5 + noise_scale / 2
+        #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
+        # else:
+        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output.clone()
+        return grad_output.clone()#, None
 
 MyModule = torch.jit.ScriptModule
@@ -258,6 +264,11 @@ class R_DECODER(MyModule):
 
 ########################################################################################################
 
+def cosine_loss(x, y):
+    x = F.normalize(x, dim=-1)
+    y = F.normalize(y, dim=-1)
+    return 1 - torch.einsum('ij,ij->i',[x,y])
+
 class RWKV_IMG(pl.LightningModule):
     def __init__(self, args):
         super().__init__()
@@ -266,6 +277,7 @@ class RWKV_IMG(pl.LightningModule):
 
         self.encoder = R_ENCODER(args)
         self.decoder = R_DECODER(args)
+        self.clip_model = None
         clip_name = args.my_img_clip
         if clip_name == 'B32':
             clip_name = 'ViT-B/32'
@@ -273,7 +285,18 @@ class RWKV_IMG(pl.LightningModule):
             clip_name = 'ViT-B/16'
         elif clip_name == 'L14':
             clip_name = 'ViT-L/14'
-        self.clip_model, self.clip_preprocess = clip.load(clip_name, jit = True)
+        elif clip_name == 'OB32':
+            clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+            self.clip_model = CLIPModel.from_pretrained(clip_name)
+            self.clip_model.encode_image = self.clip_model.get_image_features
+        if self.clip_model == None:
+            self.clip_model, _ = clip.load(clip_name, jit = True)
+        self.register_buffer(
+            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
+        )
+        self.register_buffer(
+            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)
+        )
 
         for n, p in self.named_parameters():
             if 'clip_model' in n:
@@ -320,7 +343,7 @@ class RWKV_IMG(pl.LightningModule):
 
     def forward(self, img):
         z = self.encoder(img)
-        z = ToBinary.apply(z)
+        z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
         out = self.decoder(z)
         return out
 
@@ -340,15 +363,18 @@ class RWKV_IMG(pl.LightningModule):
                     out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
                 )
 
-        # loss_l1 = F.l1_loss(out, img)
         # loss_ssim = 1 - self.loss_ssim(out, img)
-        # return loss_dists# * 1# + loss_l1 * 1 + loss_ssim * 0.4
-
         loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
-        loss_clip = F.mse_loss(self.clip_model.encode_image(img), self.clip_model.encode_image(out))
+        iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
+        ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
+        loss_clip = torch.mean(cosine_loss(iii, ooo))
 
-        return loss_dists + loss_clip * args.my_img_clip_scale
+        if args.my_img_l1_scale > 0:
+            loss_l1 = F.l1_loss(out, img)
+            return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
+        else:
+            return loss_dists + loss_clip * args.my_img_clip_scale
 
     def training_step_end(self, batch_parts):
         all = self.all_gather(batch_parts)
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index aed2c83..f32fb06 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -111,8 +111,16 @@ class train_callback(pl.Callback):
         args = self.args
         if trainer.is_global_zero:  # logging & save state_dict
             if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
+                if args.data_type == 'wds_img':
+                    raw_dict = pl_module.state_dict()
+                    to_save_dict = {}
+                    for k in raw_dict:
+                        if k.startswith('encoder.') or k.startswith('decoder.'):
+                            to_save_dict[k] = raw_dict[k]
+                else:
+                    to_save_dict = pl_module.state_dict()
                 torch.save(
-                    pl_module.state_dict(),
+                    to_save_dict,
                     f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                 )
                 trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 4423c33..5d8adc4 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -84,11 +84,14 @@ if __name__ == "__main__":
     parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
 
-    parser.add_argument("--my_img_version", default=0, type=int)
+    parser.add_argument("--my_img_version", default=0, type=str)
     parser.add_argument("--my_img_size", default=0, type=int)
     parser.add_argument("--my_img_bit", default=0, type=int)
     parser.add_argument("--my_img_clip", default='x', type=str)
     parser.add_argument("--my_img_clip_scale", default=1, type=float)
+    parser.add_argument("--my_img_l1_scale", default=0, type=float)
+    parser.add_argument("--my_img_encoder", default='x', type=str)
+    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
 
     parser = Trainer.add_argparse_args(parser)
     args = parser.parse_args()
@@ -105,7 +108,7 @@ if __name__ == "__main__":
     from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 
     if args.random_seed >= 0:
-        print(f"########## WARNING: GLOBAL SEED SET TO f{args.random_seed} ##########\n" * 3)
+        print(f"########## WARNING: GLOBAL SEED SET TO {args.random_seed} ##########\n" * 3)
         seed_everything(args.random_seed)
 
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
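
Note on the two helpers touched above, with a standalone sketch (illustration only, not part of the patch): ToBinary now floors to a hard 0/1 code in the forward pass while passing gradients straight through in backward, and cosine_loss returns the per-sample cosine distance between L2-normalized feature vectors, which training_step averages over CLIP image embeddings in place of the old MSE. The feat_img / feat_out tensors below are random stand-ins for clip_model.encode_image outputs.

import torch
import torch.nn.functional as F

class ToBinary(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + 0.5)              # hard 0/1 code in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()               # straight-through gradient

def cosine_loss(x, y):
    x = F.normalize(x, dim=-1)                   # L2-normalize each feature vector
    y = F.normalize(y, dim=-1)
    return 1 - torch.einsum('ij,ij->i', [x, y])  # per-sample cosine distance

if __name__ == "__main__":
    z = torch.rand(4, 16, requires_grad=True)
    code = ToBinary.apply(z)                     # binary latent; gradient still reaches z
    feat_img = torch.randn(4, 512)               # stand-in for clip_model.encode_image(img)
    feat_out = torch.randn(4, 512)               # stand-in for clip_model.encode_image(out)
    loss = torch.mean(cosine_loss(feat_img, feat_out)) + code.mean()
    loss.backward()
    print(code.unique())                         # typically tensor([0., 1.])
    print(z.grad.shape)                          # torch.Size([4, 16])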