diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 43c7f26..a86d5a0 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -47,6 +47,7 @@ class MyDataset(Dataset):
             self.vocab_size = -1
             self.data_size = -1
             self.data = None
+            self.error_count = 0
         else:
             if args.data_type == "dummy":
                 print("Building dummy data...")
@@ -88,7 +89,7 @@ class MyDataset(Dataset):
         # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
 
         if args.data_type == "wds_img":
-            if self.data == None:
+            def init_wds(self, bias=0):
                 def identity(x):
                     return x
                 import webdataset as wds
@@ -100,17 +101,28 @@ class MyDataset(Dataset):
                     transforms.CenterCrop(512),
                     transforms.Resize((args.my_img_size))
                 ])
-                self.data_raw = wds.WebDataset(args.data_file, resampled=True, handler=wds.handlers.warn_and_continue).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
+                self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
                 for pp in self.data_raw.pipeline:
                     if 'Resampled' in str(pp):
                         pp.deterministic = True
                         def worker_seed():
-                            return rank*100000+epoch
+                            return rank*100000+epoch+bias*1e9
                         pp.worker_seed = worker_seed
                 self.data = iter(self.data_raw)
                 # print(f"WebDataset loaded for rank {rank} epoch {epoch}")
-
-            dd = next(self.data) # jpg, json, txt
+            if self.data == None:
+                init_wds(self)
+            trial = 0
+            while trial < 10:
+                try:
+                    dd = next(self.data) # jpg, json, txt
+                    break
+                except:
+                    print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
+                    self.error_count += 1
+                    init_wds(self, self.error_count)
+                    trial += 1
+                    pass
             # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
             # with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
             #     tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
diff --git a/RWKV-v4neo/src/model_img.py b/RWKV-v4neo/src/model_img.py
index c46378c..2433723 100644
--- a/RWKV-v4neo/src/model_img.py
+++ b/RWKV-v4neo/src/model_img.py
@@ -3,11 +3,24 @@
 ########################################################################################################
 
 import numpy as np
-import os
+import os, math, gc
 import torch
-from torchvision import models
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision as vision
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorch_lightning.strategies import DeepSpeedStrategy
+import deepspeed
+from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+# from pytorch_msssim import MS_SSIM
+
+def __nop(ob):
+    return ob
+MyModule = torch.jit.ScriptModule
+# MyFunction = __nop
+MyFunction = torch.jit.script_method
+
 
 import clip
 from transformers import CLIPModel
@@ -39,7 +52,7 @@ class L2pooling(nn.Module):
 class DISTS(torch.nn.Module):
     def __init__(self, load_weights=True):
         super(DISTS, self).__init__()
-        vgg_pretrained_features = models.vgg16(
+        vgg_pretrained_features = vision.models.vgg16(
             weights="VGG16_Weights.IMAGENET1K_V1"
         ).features
         self.stage1 = torch.nn.Sequential()
@@ -134,39 +147,19 @@ class DISTS(torch.nn.Module):
         else:
             return score
 
+class ToBinary(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):#, noise_scale):
+        # if noise_scale > 0:
+        #     noise_min = 0.5 - noise_scale / 2
+        #     noise_max = 0.5 + noise_scale / 2
+        #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
+        # else:
+        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
 
-########################################################################################################
-
-import os, math, gc
-import torchvision as vision
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import pytorch_lightning as pl
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-from pytorch_lightning.strategies import DeepSpeedStrategy
-import deepspeed
-from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
-# from pytorch_msssim import MS_SSIM
-
-
-class ToBinary(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x):#, noise_scale):
-        # if noise_scale > 0:
-        #     noise_min = 0.5 - noise_scale / 2
-        #     noise_max = 0.5 + noise_scale / 2
-        #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
-        # else:
-        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output.clone()#, None
-
-
-MyModule = torch.jit.ScriptModule
-MyFunction = torch.jit.script_method
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.clone()#, None
 
 ########################################################################################################
 
@@ -174,30 +167,45 @@ class R_ENCODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
-
-        self.B00 = nn.BatchNorm2d(12)
-        self.C00 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C01 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-        self.C02 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C03 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-
-        self.B10 = nn.BatchNorm2d(48)
-        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-
-        self.B20 = nn.BatchNorm2d(192)
-        self.C20 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C21 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C22 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C23 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-
-        self.COUT = nn.Conv2d(192, args.my_img_bit, kernel_size=3, padding=1)
+        dd = 8
+        self.Bxx = nn.BatchNorm2d(dd*64)
+
+        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+
+        self.B00 = nn.BatchNorm2d(dd*4)
+        self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+
+        self.B20 = nn.BatchNorm2d(dd*64)
+        self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.B21 = nn.BatchNorm2d(dd*64)
+        # self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+
+        self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
 
     @MyFunction
-    def forward(self, x):
-        ACT = F.silu
+    def forward(self, img):
+        ACT = F.mish
+
+        x = self.CIN(img)
+        xx = self.Bxx(F.pixel_unshuffle(x, 8))
+        x = x + self.Cx1(ACT(self.Cx0(x)))
 
         x = F.pixel_unshuffle(x, 2)
         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
@@ -210,9 +218,10 @@ class R_ENCODER(MyModule):
         x = F.pixel_unshuffle(x, 2)
         x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
         x = x + self.C23(ACT(self.C22(x)))
+        # x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
+        # x = x + self.C27(ACT(self.C26(x)))
 
-        x = self.COUT(x)
-
+        x = self.COUT(x + xx)
         return torch.sigmoid(x)
 
 ########################################################################################################
@@ -221,35 +230,45 @@ class R_DECODER(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
-
-        self.CIN = nn.Conv2d(args.my_img_bit, 192, kernel_size=3, padding=1)
-
-        self.B00 = nn.BatchNorm2d(192)
-        self.C00 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C01 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C02 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-        self.C03 = nn.Conv2d(192, 192, kernel_size=3, padding=1)
-
-        self.B10 = nn.BatchNorm2d(48)
-        self.C10 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C11 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-        self.C12 = nn.Conv2d(48, 192, kernel_size=3, padding=1)
-        self.C13 = nn.Conv2d(192, 48, kernel_size=3, padding=1)
-
-        self.B20 = nn.BatchNorm2d(12)
-        self.C20 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C21 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
-        self.C22 = nn.Conv2d(12, 96, kernel_size=3, padding=1)
-        self.C23 = nn.Conv2d(96, 12, kernel_size=3, padding=1)
+        dd = 8
+        self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
+
+        self.B00 = nn.BatchNorm2d(dd*64)
+        self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.B01 = nn.BatchNorm2d(dd*64)
+        # self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        # self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        # self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+
+        self.B20 = nn.BatchNorm2d(dd*4)
+        self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
 
     @MyFunction
-    def forward(self, x):
-        ACT = F.silu
-
-        x = self.CIN(x)
+    def forward(self, code):
+        ACT = F.mish
+        x = self.CIN(code)
 
         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
         x = x + self.C03(ACT(self.C02(x)))
+        # x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
+        # x = x + self.C07(ACT(self.C06(x)))
 
         x = F.pixel_shuffle(x, 2)
         x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
@@ -260,9 +279,12 @@
         x = x + self.C23(ACT(self.C22(x)))
         x = F.pixel_shuffle(x, 2)
+        x = x + self.Cx1(ACT(self.Cx0(x)))
+        x = self.COUT(x)
+
         return torch.sigmoid(x)
 
-########################################################################################################
+########################################################################################################
 
 def cosine_loss(x, y):
     x = F.normalize(x, dim=-1)
     y = F.normalize(y, dim=-1)
@@ -292,10 +314,10 @@ class RWKV_IMG(pl.LightningModule):
         if self.clip_model == None:
             self.clip_model, _ = clip.load(clip_name, jit = True)
         self.register_buffer(
-            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1)
+            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
         )
         self.register_buffer(
-            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1)
+            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
         )
 
         for n, p in self.named_parameters():
@@ -393,8 +415,23 @@
         )
         m = {}
         for n in self.state_dict():
+            scale = 1
             p = self.state_dict()[n]
             shape = p.shape
+            ss = n.split('.')
+
+            # if ss[0] in ['encoder', 'decoder']:
+            #     if ss[2] == 'bias':
+            #         scale = 0
+            #     # elif n == 'encoder.CIN.weight':
+            #     #     nn.init.dirac_(p)
+            #     else:
+            #         try:
+            #             if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
+            #                 scale = 0
+            #         except:
+            #             pass
+            # m[n] = p * scale
             m[n] = p
diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index f32fb06..8a648eb 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -18,23 +18,23 @@ class train_callback(pl.Callback):
 
         # LR schedule
         w_step = args.warmup_steps
-        if trainer.global_step < w_step:
-            lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step)
+        if args.lr_final == args.lr_init or args.epoch_count == 0:
+            lr = args.lr_init
         else:
-            if args.lr_final == args.lr_init or args.epoch_count == 0:
-                lr = args.lr_init
-            else:
-                decay_step = real_step - args.my_pile_edecay * args.epoch_steps
-                decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
-                progress = (decay_step - w_step + 1) / (decay_total - w_step)
-                progress = min(1, max(0, progress))
-
-                if args.lr_final == 0 or args.lr_init == 0: # linear decay
-                    lr = args.lr_init + (args.lr_final - args.lr_init) * progress
-                else: # exp decay
-                    lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
-        # if trainer.is_global_zero:
-        #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
+            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
+            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
+            progress = (decay_step - w_step + 1) / (decay_total - w_step)
+            progress = min(1, max(0, progress))
+
+            if args.lr_final == 0 or args.lr_init == 0: # linear decay
+                lr = args.lr_init + (args.lr_final - args.lr_init) * progress
+            else: # exp decay
+                lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+
+        if trainer.global_step < w_step:
+            lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
+        # if trainer.is_global_zero:
+        #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
 
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0: