main
BlinkDL 3 years ago
parent a86b15a3da
commit 40237495cc

@ -47,6 +47,7 @@ class MyDataset(Dataset):
self.vocab_size = -1
self.data_size = -1
self.data = None
self.error_count = 0
else:
if args.data_type == "dummy":
print("Building dummy data...")
@ -88,7 +89,7 @@ class MyDataset(Dataset):
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
if args.data_type == "wds_img":
if self.data == None:
def init_wds(self, bias=0):
def identity(x):
return x
import webdataset as wds
@ -100,17 +101,28 @@ class MyDataset(Dataset):
transforms.CenterCrop(512),
transforms.Resize((args.my_img_size))
])
self.data_raw = wds.WebDataset(args.data_file, resampled=True, handler=wds.handlers.warn_and_continue).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
for pp in self.data_raw.pipeline:
if 'Resampled' in str(pp):
pp.deterministic = True
def worker_seed():
return rank*100000+epoch
return rank*100000+epoch+bias*1e9
pp.worker_seed = worker_seed
self.data = iter(self.data_raw)
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
if self.data == None:
init_wds(self)
trial = 0
while trial < 10:
try:
dd = next(self.data) # jpg, json, txt
break
except:
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
self.error_count += 1
init_wds(self, self.error_count)
trial += 1
pass
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")

@ -3,11 +3,24 @@
########################################################################################################
import numpy as np
import os
import os, math, gc
import torch
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from pytorch_msssim import MS_SSIM
def __nop(ob):
    """Identity pass-through: return *ob* unchanged.

    Kept as a no-op substitute for the TorchScript decorators
    (see the commented-out ``MyFunction = __nop`` alternative nearby),
    so JIT scripting can be toggled off without touching call sites.
    """
    result = ob
    return result
MyModule = torch.jit.ScriptModule
# MyFunction = __nop
MyFunction = torch.jit.script_method
import clip
from transformers import CLIPModel
@ -39,7 +52,7 @@ class L2pooling(nn.Module):
class DISTS(torch.nn.Module):
def __init__(self, load_weights=True):
super(DISTS, self).__init__()
vgg_pretrained_features = models.vgg16(
vgg_pretrained_features = vision.models.vgg16(
weights="VGG16_Weights.IMAGENET1K_V1"
).features
self.stage1 = torch.nn.Sequential()
@ -134,22 +147,6 @@ class DISTS(torch.nn.Module):
else:
return score
########################################################################################################
import os, math, gc
import torchvision as vision
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from pytorch_msssim import MS_SSIM
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):#, noise_scale):
@ -164,40 +161,51 @@ class ToBinary(torch.autograd.Function):
def backward(ctx, grad_output):
return grad_output.clone()#, None
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
########################################################################################################
class R_ENCODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.B00 = nn.BatchNorm2d(12)
self.C00 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(48)
self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(192)
self.C20 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(192, args.my_img_bit, kernel_size=3, padding=1)
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B21 = nn.BatchNorm2d(dd*64)
# self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
@MyFunction
def forward(self, x):
ACT = F.silu
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
@ -210,9 +218,10 @@ class R_ENCODER(MyModule):
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
# x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
# x = x + self.C27(ACT(self.C26(x)))
x = self.COUT(x)
x = self.COUT(x + xx)
return torch.sigmoid(x)
########################################################################################################
@ -221,35 +230,45 @@ class R_DECODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.CIN = nn.Conv2d(args.my_img_bit, 192, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(192)
self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(48)
self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(12)
self.C20 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B01 = nn.BatchNorm2d(dd*64)
# self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
@MyFunction
def forward(self, x):
ACT = F.silu
x = self.CIN(x)
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
# x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
# x = x + self.C07(ACT(self.C06(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
@ -260,9 +279,12 @@ class R_DECODER(MyModule):
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
########################################################################################################
def cosine_loss(x, y):
x = F.normalize(x, dim=-1)
@ -292,10 +314,10 @@ class RWKV_IMG(pl.LightningModule):
if self.clip_model == None:
self.clip_model, _ = clip.load(clip_name, jit = True)
self.register_buffer(
"clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
"clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
)
self.register_buffer(
"clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)
"clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
)
for n, p in self.named_parameters():
@ -393,8 +415,23 @@ class RWKV_IMG(pl.LightningModule):
)
m = {}
for n in self.state_dict():
scale = 1
p = self.state_dict()[n]
shape = p.shape
ss = n.split('.')
# if ss[0] in ['encoder', 'decoder']:
# if ss[2] == 'bias':
# scale = 0
# # elif n == 'encoder.CIN.weight':
# # nn.init.dirac_(p)
# else:
# try:
# if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
# scale = 0
# except:
# pass
# m[n] = p * scale
m[n] = p

@ -18,9 +18,6 @@ class train_callback(pl.Callback):
# LR schedule
w_step = args.warmup_steps
if trainer.global_step < w_step:
lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step)
else:
if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init
else:
@ -33,6 +30,9 @@ class train_callback(pl.Callback):
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
if trainer.global_step < w_step:
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
# if trainer.is_global_zero:
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)

Loading…
Cancel
Save