main
BlinkDL 3 years ago
parent a86b15a3da
commit 40237495cc

@@ -47,6 +47,7 @@ class MyDataset(Dataset):
             self.vocab_size = -1
             self.data_size = -1
             self.data = None
+            self.error_count = 0
         else:
             if args.data_type == "dummy":
                 print("Building dummy data...")
@ -88,7 +89,7 @@ class MyDataset(Dataset):
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
if args.data_type == "wds_img": if args.data_type == "wds_img":
if self.data == None: def init_wds(self, bias=0):
def identity(x): def identity(x):
return x return x
import webdataset as wds import webdataset as wds
@@ -100,17 +101,28 @@ class MyDataset(Dataset):
                     transforms.CenterCrop(512),
                     transforms.Resize((args.my_img_size))
                 ])
-                self.data_raw = wds.WebDataset(args.data_file, resampled=True, handler=wds.handlers.warn_and_continue).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
+                self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
                 for pp in self.data_raw.pipeline:
                     if 'Resampled' in str(pp):
                         pp.deterministic = True
                         def worker_seed():
-                            return rank*100000+epoch
+                            return rank*100000+epoch+bias*1e9
                         pp.worker_seed = worker_seed
                 self.data = iter(self.data_raw)
                 # print(f"WebDataset loaded for rank {rank} epoch {epoch}")
-            dd = next(self.data) # jpg, json, txt
+            if self.data == None:
+                init_wds(self)
+            trial = 0
+            while trial < 10:
+                try:
+                    dd = next(self.data) # jpg, json, txt
+                    break
+                except:
+                    print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
+                    self.error_count += 1
+                    init_wds(self, self.error_count)
+                    trial += 1
+                    pass
             # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
             # with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
             #     tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
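Note on this hunk: the getter no longer crashes when a WebDataset sample fails to decode. `next(self.data)` is wrapped in a retry loop; on failure the pipeline is rebuilt via `init_wds` with the running `error_count` as a seed bias, so the fresh shuffle is unlikely to hit the same bad sample first. A minimal self-contained sketch of the same retry-with-reseed pattern (the `make_stream` helper is hypothetical, standing in for `init_wds`):

    import random

    def make_stream(seed, data):
        # Stand-in for init_wds: yields items in a seed-dependent order.
        rng = random.Random(seed)
        items = list(data)
        rng.shuffle(items)
        return iter(items)

    def next_with_retry(data, base_seed=0, max_trials=10):
        stream = make_stream(base_seed, data)
        error_count = 0
        for trial in range(max_trials):
            try:
                return next(stream)
            except Exception:
                # Rebuild the pipeline with a biased seed so the new shuffle
                # does not immediately replay the failing sample.
                error_count += 1
                stream = make_stream(base_seed + error_count * 10**9, data)
        raise RuntimeError(f"dataloader failed after {max_trials} trials")

    print(next_with_retry(["a", "b", "c"], base_seed=42))

Capping at 10 trials keeps one persistently broken shard from hanging training forever.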

@@ -3,11 +3,24 @@
 ########################################################################################################
 
 import numpy as np
-import os
+import os, math, gc
 import torch
-from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision as vision
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorch_lightning.strategies import DeepSpeedStrategy
+import deepspeed
+from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+# from pytorch_msssim import MS_SSIM
+
+def __nop(ob):
+    return ob
+MyModule = torch.jit.ScriptModule
+# MyFunction = __nop
+MyFunction = torch.jit.script_method
 
 import clip
 from transformers import CLIPModel
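Note: the `__nop`/`MyModule`/`MyFunction` block is the usual RWKV TorchScript switch. Model classes derive from `MyModule` and hot methods are decorated with `@MyFunction`, so swapping in the commented `MyFunction = __nop` (with a plain `nn.Module` base) turns scripting off for debugging without touching the model code. A hedged sketch of the toggle in use (the `USE_JIT` flag and `Tiny` class are illustrative, not in the commit):

    import torch
    import torch.nn as nn

    def __nop(ob):
        return ob

    # Toggle: compile with TorchScript, or run eagerly for easier debugging.
    USE_JIT = True
    MyModule = torch.jit.ScriptModule if USE_JIT else nn.Module
    MyFunction = torch.jit.script_method if USE_JIT else __nop

    class Tiny(MyModule):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(4, 4)

        @MyFunction
        def forward(self, x):
            return torch.relu(self.lin(x))

    print(Tiny()(torch.zeros(1, 4)))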
@@ -39,7 +52,7 @@ class L2pooling(nn.Module):
 class DISTS(torch.nn.Module):
     def __init__(self, load_weights=True):
         super(DISTS, self).__init__()
-        vgg_pretrained_features = models.vgg16(
+        vgg_pretrained_features = vision.models.vgg16(
             weights="VGG16_Weights.IMAGENET1K_V1"
         ).features
         self.stage1 = torch.nn.Sequential()
@@ -134,39 +147,19 @@ class DISTS(torch.nn.Module):
         else:
             return score
 
-########################################################################################################
-
-import os, math, gc
-import torchvision as vision
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import pytorch_lightning as pl
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-from pytorch_lightning.strategies import DeepSpeedStrategy
-import deepspeed
-from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
-# from pytorch_msssim import MS_SSIM
-
 class ToBinary(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):#, noise_scale):
         # if noise_scale > 0:
         #     noise_min = 0.5 - noise_scale / 2
         #     noise_max = 0.5 + noise_scale / 2
         #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
         # else:
         return torch.floor(x + 0.5) # no need for noise when we have plenty of data
 
     @staticmethod
     def backward(ctx, grad_output):
         return grad_output.clone()#, None
 
-MyModule = torch.jit.ScriptModule
-MyFunction = torch.jit.script_method
-
 ########################################################################################################
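Note: `ToBinary` (moved up next to the model code, otherwise unchanged) is a straight-through estimator. `forward` rounds the encoder's sigmoid outputs to hard 0/1 bits, while `backward` passes the incoming gradient through untouched, as if the rounding were the identity; that is what keeps the binarized bottleneck trainable. A quick standalone demonstration:

    import torch

    class ToBinary(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return torch.floor(x + 0.5)  # round to {0, 1}

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output.clone()  # straight-through: treat forward as identity

    x = torch.tensor([0.2, 0.7, 0.49], requires_grad=True)
    bits = ToBinary.apply(x)
    print(bits)            # tensor([0., 1., 0.])
    bits.sum().backward()
    print(x.grad)          # tensor([1., 1., 1.]) -- gradient passes through the rounding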
@@ -174,30 +167,45 @@ class R_ENCODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
+        dd = 8
+        self.Bxx = nn.BatchNorm2d(dd*64)
 
-        self.B00 = nn.BatchNorm2d(12)
-        self.C00 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C01 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-        self.C02 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C03 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
+        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
 
-        self.B10 = nn.BatchNorm2d(48)
-        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
+        self.B00 = nn.BatchNorm2d(dd*4)
+        self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
 
-        self.B20 = nn.BatchNorm2d(192)
-        self.C20 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C21 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.COUT = nn.Conv2d(192, args.my_img_bit, kernel_size=3, padding=1)
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+
+        self.B20 = nn.BatchNorm2d(dd*64)
+        self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.B21 = nn.BatchNorm2d(dd*64)
+        # self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
 
     @MyFunction
-    def forward(self, x):
-        ACT = F.silu
+    def forward(self, img):
+        ACT = F.mish
+
+        x = self.CIN(img)
+        xx = self.Bxx(F.pixel_unshuffle(x, 8))
+        x = x + self.Cx1(ACT(self.Cx0(x)))
 
         x = F.pixel_unshuffle(x, 2)
         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
@@ -210,9 +218,10 @@ class R_ENCODER(MyModule):
         x = F.pixel_unshuffle(x, 2)
         x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
         x = x + self.C23(ACT(self.C22(x)))
+        # x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
+        # x = x + self.C27(ACT(self.C26(x)))
 
-        x = self.COUT(x)
+        x = self.COUT(x + xx)
         return torch.sigmoid(x)
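Note on the channel arithmetic: `F.pixel_unshuffle(x, r)` divides H and W by `r` and multiplies channels by `r²`, so after `CIN` (3 → `dd`) the three stage-wise unshuffles take `dd` → `dd*4` → `dd*16` → `dd*64`, matching the BatchNorm sizes, and the new `Bxx` shortcut reaches the same `dd*64` shape in a single `pixel_unshuffle(x, 8)` (8² = 64), which is why `x + xx` is valid at `COUT`. A shape check, assuming a 256×256 input:

    import torch
    import torch.nn.functional as F

    dd = 8
    x = torch.zeros(1, dd, 256, 256)          # after self.CIN: 3 -> dd channels

    xx = F.pixel_unshuffle(x, 8)              # shortcut: dd*64 channels at 32x32
    print(xx.shape)                           # torch.Size([1, 512, 32, 32])

    for stage in range(3):
        x = F.pixel_unshuffle(x, 2)           # channels x4, spatial /2 each stage
        print(x.shape)
    # torch.Size([1, 32, 128, 128])   dd*4
    # torch.Size([1, 128, 64, 64])    dd*16
    # torch.Size([1, 512, 32, 32])    dd*64 -- same shape as xx, so x + xx works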
 ########################################################################################################
@@ -221,35 +230,45 @@ class R_DECODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
+        dd = 8
 
-        self.CIN = nn.Conv2d(args.my_img_bit, 192, kernel_size=3, padding=1)
+        self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
 
-        self.B00 = nn.BatchNorm2d(192)
-        self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C01 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C02 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C03 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
+        self.B00 = nn.BatchNorm2d(dd*64)
+        self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.B01 = nn.BatchNorm2d(dd*64)
+        # self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
 
-        self.B10 = nn.BatchNorm2d(48)
-        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
 
-        self.B20 = nn.BatchNorm2d(12)
-        self.C20 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C21 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-        self.C22 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C23 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
+        self.B20 = nn.BatchNorm2d(dd*4)
+        self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
 
     @MyFunction
-    def forward(self, x):
-        ACT = F.silu
+    def forward(self, code):
+        ACT = F.mish
 
-        x = self.CIN(x)
+        x = self.CIN(code)
         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
         x = x + self.C03(ACT(self.C02(x)))
+        # x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
+        # x = x + self.C07(ACT(self.C06(x)))
 
         x = F.pixel_shuffle(x, 2)
         x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
@@ -260,9 +279,12 @@ class R_DECODER(MyModule):
         x = x + self.C23(ACT(self.C22(x)))
 
         x = F.pixel_shuffle(x, 2)
+        x = x + self.Cx1(ACT(self.Cx0(x)))
+        x = self.COUT(x)
         return torch.sigmoid(x)
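Note: the decoder now mirrors the encoder stage for stage. `F.pixel_shuffle(x, 2)` divides channels by 4 and doubles H and W, walking `dd*64` back down to `dd`, and the new `Cx0`/`Cx1`/`COUT` tail mirrors the encoder's `CIN`/`Cx0`/`Cx1` head. A sketch of the shapes, with `my_img_bit = 13` as an assumed example value:

    import torch
    import torch.nn.functional as F

    dd, my_img_bit = 8, 13                     # my_img_bit value assumed for illustration
    x = torch.zeros(1, my_img_bit, 32, 32)     # binary code for a 256x256 image
    x = torch.nn.Conv2d(my_img_bit, dd * 64, 3, padding=1)(x)  # decoder CIN
    for stage in range(3):
        x = F.pixel_shuffle(x, 2)              # channels /4, spatial x2 each stage
        print(x.shape)
    # torch.Size([1, 128, 64, 64])    dd*16
    # torch.Size([1, 32, 128, 128])   dd*4
    # torch.Size([1, 8, 256, 256])    dd -- Cx0/Cx1 refine, then COUT maps dd -> 3 (RGB)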
 ########################################################################################################
 
 def cosine_loss(x, y):
     x = F.normalize(x, dim=-1)
@@ -292,10 +314,10 @@ class RWKV_IMG(pl.LightningModule):
         if self.clip_model == None:
             self.clip_model, _ = clip.load(clip_name, jit = True)
         self.register_buffer(
-            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
+            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
         )
         self.register_buffer(
-            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)
+            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
         )
 
         for n, p in self.named_parameters():
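Note: replacing `view(1, -1, 1, 1)` with `view(1, 3, 1, 1)` pins the buffer shape explicitly rather than inferring it; behavior is unchanged, since either way the buffers broadcast over an NCHW batch when normalizing images for CLIP. A sketch of that broadcast (how the buffers are applied elsewhere in the model is assumed, not shown in this hunk):

    import torch

    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

    img = torch.rand(4, 3, 224, 224)           # NCHW batch in [0, 1]
    img_clip = (img - clip_mean) / clip_std    # broadcasts over N, H, W
    print(img_clip.shape)                      # torch.Size([4, 3, 224, 224])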
@@ -393,8 +415,23 @@ class RWKV_IMG(pl.LightningModule):
         )
         m = {}
         for n in self.state_dict():
+            scale = 1
             p = self.state_dict()[n]
             shape = p.shape
+            ss = n.split('.')
+
+            # if ss[0] in ['encoder', 'decoder']:
+            #     if ss[2] == 'bias':
+            #         scale = 0
+            #     # elif n == 'encoder.CIN.weight':
+            #     #     nn.init.dirac_(p)
+            #     else:
+            #         try:
+            #             if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
+            #                 scale = 0
+            #         except:
+            #             pass
+            # m[n] = p * scale
             m[n] = p
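Note: the commented-out block sketches a zero-init scheme that this commit leaves disabled (`m[n] = p` still copies weights unscaled): `scale = 0` for biases and for odd-numbered convs (`C01`, `C03`, `C11`, ...) would zero the closing conv of each residual branch, so every `x = x + C_odd(ACT(C_even(...)))` block starts as an identity map. As standalone code the idea would look like this (a hypothetical sketch, not what the commit enables):

    # Hypothetical: zero the closing conv of each residual branch at init.
    def init_scale(name):
        ss = name.split('.')                   # e.g. 'encoder.C01.weight'
        if ss[0] in ['encoder', 'decoder']:
            if ss[-1] == 'bias':
                return 0
            try:
                # C01, C03, C11, ... close a residual add; zeroing them
                # makes the whole block an identity map at initialization.
                if ss[1][0] == 'C' and int(ss[1][2]) % 2 == 1:
                    return 0
            except (IndexError, ValueError):
                pass                           # Bxx, CIN, COUT etc. keep scale 1
        return 1

    print(init_scale('encoder.C01.weight'))    # 0
    print(init_scale('encoder.C00.weight'))    # 1
    print(init_scale('decoder.B10.bias'))      # 0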

@@ -18,23 +18,23 @@ class train_callback(pl.Callback):
         # LR schedule
         w_step = args.warmup_steps
-        if trainer.global_step < w_step:
-            lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step)
+        if args.lr_final == args.lr_init or args.epoch_count == 0:
+            lr = args.lr_init
         else:
-            if args.lr_final == args.lr_init or args.epoch_count == 0:
-                lr = args.lr_init
-            else:
-                decay_step = real_step - args.my_pile_edecay * args.epoch_steps
-                decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
-                progress = (decay_step - w_step + 1) / (decay_total - w_step)
-                progress = min(1, max(0, progress))
+            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
+            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
+            progress = (decay_step - w_step + 1) / (decay_total - w_step)
+            progress = min(1, max(0, progress))
 
-                if args.lr_final == 0 or args.lr_init == 0:  # linear decay
-                    lr = args.lr_init + (args.lr_final - args.lr_init) * progress
-                else:  # exp decay
-                    lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+            if args.lr_final == 0 or args.lr_init == 0:  # linear decay
+                lr = args.lr_init + (args.lr_final - args.lr_init) * progress
+            else:  # exp decay
+                lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
 
+        if trainer.global_step < w_step:
+            lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
         # if trainer.is_global_zero:
         #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
 
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0:
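Note: the schedule rework turns warmup from a separate branch into a multiplier. Before, steps under `w_step` ignored the decay curve entirely; now the decayed `lr` is computed first and warmup scales it by `0.2 + 0.8 * step / w_step`, so the two effects compose. A self-contained sketch of the new path (hyperparameter values are assumed, and the `lr_final == lr_init` shortcut branch is omitted):

    import math

    def lr_at(step, lr_init=3e-4, lr_final=1e-5, w_step=500,
              epoch_steps=1000, epoch_count=20, edecay=0):
        # Exponential decay from lr_init to lr_final (assumed hyperparameters).
        decay_step = step - edecay * epoch_steps
        decay_total = (epoch_count - edecay) * epoch_steps
        progress = (decay_step - w_step + 1) / (decay_total - w_step)
        progress = min(1, max(0, progress))
        lr = lr_init * math.exp(math.log(lr_final / lr_init) * progress)
        # Warmup now multiplies the decayed lr instead of replacing it.
        if step < w_step:
            lr = lr * (0.2 + 0.8 * step / w_step)
        return lr

    print(lr_at(0))       # ~0.2 * lr_init at the very first step
    print(lr_at(500))     # warmup done, pure decay curve from here
    print(lr_at(20000))   # progress clamped to 1 -> lr_final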
