diff --git a/RWKV-v4neo/img_demoAE.py b/RWKV-v4neo/img_demoAE.py
new file mode 100644
index 0000000..ab0d4ed
--- /dev/null
+++ b/RWKV-v4neo/img_demoAE.py
@@ -0,0 +1,165 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import torch, types, os
+import numpy as np
+from PIL import Image
+import torch.nn as nn
+from torch.nn import functional as F
+import torchvision as vision
+import torchvision.transforms as transforms
+np.set_printoptions(precision=4, suppress=True, linewidth=200)
+print(f'loading...')
+
+########################################################################################################
+
+model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201'
+input_img = 'test/img_ae_test/test0.png'
+
+########################################################################################################
+
+class ToBinary(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.clone() # pass-through
+
+class R_ENCODER(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+        dd = 8
+        self.Bxx = nn.BatchNorm2d(dd*64)
+
+        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+
+        self.B00 = nn.BatchNorm2d(dd*4)
+        self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+
+        self.B20 = nn.BatchNorm2d(dd*64)
+        self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+
+        self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
+
+    def forward(self, img):
+        ACT = F.mish
+
+        x = self.CIN(img)
+        xx = self.Bxx(F.pixel_unshuffle(x, 8))
+        x = x + self.Cx1(ACT(self.Cx0(x)))
+
+        x = F.pixel_unshuffle(x, 2)
+        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
+        x = x + self.C03(ACT(self.C02(x)))
+
+        x = F.pixel_unshuffle(x, 2)
+        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
+        x = x + self.C13(ACT(self.C12(x)))
+
+        x = F.pixel_unshuffle(x, 2)
+        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
+        x = x + self.C23(ACT(self.C22(x)))
+
+        x = self.COUT(x + xx)
+        return torch.sigmoid(x)
+
+class R_DECODER(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+        dd = 8
+        self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
+
+        self.B00 = nn.BatchNorm2d(dd*64)
+        self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+
+        self.B10 = nn.BatchNorm2d(dd*16)
+        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+
+        self.B20 = nn.BatchNorm2d(dd*4)
+        self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
+        self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+
+        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
+
+    def forward(self, code):
+        ACT = F.mish
+        x = self.CIN(code)
+
+        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
+        x = x + self.C03(ACT(self.C02(x)))
+        x = F.pixel_shuffle(x, 2)
+
+        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
+        x = x + self.C13(ACT(self.C12(x)))
+        x = F.pixel_shuffle(x, 2)
+
+        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
+        x = x + self.C23(ACT(self.C22(x)))
+        x = F.pixel_shuffle(x, 2)
+
+        x = x + self.Cx1(ACT(self.Cx0(x)))
+        x = self.COUT(x)
+
+        return torch.sigmoid(x)
+
+########################################################################################################
+
+print(f'building model...')
+args = types.SimpleNamespace()
+args.my_img_bit = 13
+encoder = R_ENCODER(args).eval().cuda()
+decoder = R_DECODER(args).eval().cuda()
+
+zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long()
+
+encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth'))
+decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth'))
+
+########################################################################################################
+
+print(f'test image...')
+img_transform = transforms.Compose([
+    transforms.PILToTensor(),
+    transforms.ConvertImageDtype(torch.float),
+    transforms.Resize((224, 224))
+])
+
+with torch.no_grad():
+    img = img_transform(Image.open(input_img)).unsqueeze(0).cuda()
+    z = encoder(img)
+    z = ToBinary.apply(z)
+
+    zz = torch.sum(z.squeeze().long() * zpow, dim=0)
+    print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')
+
+    out = decoder(z)
+    vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg")
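A note on ToBinary above: it is a straight-through estimator. The forward pass rounds the encoder's sigmoid outputs to hard {0, 1} bits, while the backward pass copies gradients through unchanged, so the binarization stays trainable. A minimal sketch of that behavior (the import path and tensor values are illustrative assumptions, not part of the patch):

import torch
from img_demoAE import ToBinary  # assumes the file added above is on the import path

p = torch.tensor([0.2, 0.7, 0.49], requires_grad=True)
bits = ToBinary.apply(p)   # forward: floor(p + 0.5) -> tensor([0., 1., 0.])
bits.sum().backward()      # backward: gradient passes through as if identity
print(p.grad)              # tensor([1., 1., 1.]) -- rounding did not block gradients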
diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 37a5f40..43c7f26 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -37,6 +37,12 @@ class MyDataset(Dataset):
             print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
             self.data_size = len(self.data)
             print(f"Data has {self.data_size} tokens.")
+        elif args.data_type == "uint16":
+            self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
+            self.vocab_size = args.vocab_size
+            print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            self.data_size = self.data.shape[0]
+            print(f"Data has {self.data_size} samples.")
         elif args.data_type == "wds_img":
             self.vocab_size = -1
             self.data_size = -1
@@ -94,7 +100,7 @@
                 transforms.CenterCrop(512),
                 transforms.Resize((args.my_img_size))
             ])
-            self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
+            self.data_raw = wds.WebDataset(args.data_file, resampled=True, handler=wds.handlers.warn_and_continue).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
             for pp in self.data_raw.pipeline:
                 if 'Resampled' in str(pp):
                     pp.deterministic = True
@@ -110,27 +116,33 @@
                 # tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
             return dd[0], dd[2]
         else:
-            ctx_len = args.ctx_len
-            req_len = ctx_len + 1
-
-            if args.my_pile_stage > 0:
-                ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
-                factor = (math.sqrt(5) - 1) / 2
-                factor = int(args.magic_prime * factor)
-                i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
-                i = i + args.my_pile_shift
-                # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
+            if args.data_type == "uint16":
+                i = np.random.randint(0, self.data_size-1)
+                dix = self.data[i]
+                x = torch.tensor(dix[:-1], dtype=torch.long)
+                y = torch.tensor(dix[1:], dtype=torch.long)
             else:
-                # cheat: pick a random spot in dataset
-                i = np.random.randint(0, self.data_size - req_len)
+                ctx_len = args.ctx_len
+                req_len = ctx_len + 1
 
-            if args.data_type == "binidx":
-                dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
-            elif args.data_type == "numpy":
-                dix = self.data[i : i + req_len]
-            else:
-                dix = [self.stoi[s] for s in self.data[i : i + req_len]]
+                if args.my_pile_stage > 0:
+                    ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
+                    factor = (math.sqrt(5) - 1) / 2
+                    factor = int(args.magic_prime * factor)
+                    i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
+                    i = i + args.my_pile_shift
+                    # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
+                else:
+                    # cheat: pick a random spot in dataset
+                    i = np.random.randint(0, self.data_size - req_len)
+
+                if args.data_type == "binidx":
+                    dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
+                elif args.data_type == "numpy":
+                    dix = self.data[i : i + req_len]
+                else:
+                    dix = [self.stoi[s] for s in self.data[i : i + req_len]]
 
-            x = torch.tensor(dix[:-1], dtype=torch.long)
-            y = torch.tensor(dix[1:], dtype=torch.long)
+                x = torch.tensor(dix[:-1], dtype=torch.long)
+                y = torch.tensor(dix[1:], dtype=torch.long)
         return x, y
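A note on the new "uint16" data type: unlike the token-stream types, it expects a flat binary file of uint16 token ids (native byte order) whose length divides evenly into fixed-length rows of --my_sample_len tokens; each row becomes one (x, y) pair shifted by one token. A minimal sketch of producing such a file (the file name and token values are illustrative assumptions):

import numpy as np

my_sample_len = 8                                 # must match --my_sample_len at training time
tokens = np.arange(4 * my_sample_len) % 65536     # 4 dummy samples; real data would be tokenized text
tokens.astype(np.uint16).tofile("train.uint16")   # flat binary file, passed via --data_file

# The loader then does np.fromfile(...).reshape(-1, my_sample_len) and trains on
# x = row[:-1], y = row[1:], i.e. an effective context of my_sample_len - 1 tokens.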
diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py
index 0d6499c..18be189 100644
--- a/RWKV-v4neo/src/model_run.py
+++ b/RWKV-v4neo/src/model_run.py
@@ -44,6 +44,8 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
                 w[x] = w[x].squeeze()
             if '.time_decay' in x:
                 w[x] = -torch.exp(w[x])
+            if 'pos_emb_x' in x:
+                self.w.pos_emb = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
             if DEBUG_TIME and '.time_' in x:
                 print(x, w[x].squeeze().cpu().numpy())
 
@@ -150,6 +152,11 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
         with torch.no_grad():
            w = self.w
            x = w.emb.weight[ctx[-1]]
+            try:
+                pos_emb = w.pos_emb[len(ctx)-1]
+                x = x + pos_emb
+            except:
+                pass
 
             for i in range(self.n_layer):
                 if i == 0:
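A note on the pos_emb reconstruction above: model_run.py only consumes the checkpoint, so the tensor shapes come from the training-side model definition, which is not part of this diff. My reading (an assumption): pos_emb_x is a (1, P, C) row embedding and pos_emb_y a (P, 1, C) column embedding; broadcasting their sum yields a (P, P, C) grid of P*P position vectors, and with P*P == ctx_len + 1 the reshape/slice recovers one additive embedding per context position:

import torch

P, C = 4, 8
ctx_len = P * P - 1                                   # toy sizes chosen so P*P == ctx_len + 1
pos_emb_x = torch.randn(1, P, C)                      # stand-ins for the checkpoint tensors
pos_emb_y = torch.randn(P, 1, C)
pos_emb = (pos_emb_x + pos_emb_y).reshape(ctx_len + 1, -1)[:-1, :]
assert pos_emb.shape == (ctx_len, C)                  # one additive embedding per token position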
diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 5d8adc4..0834559 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -92,6 +92,11 @@ if __name__ == "__main__":
     parser.add_argument("--my_img_l1_scale", default=0, type=float)
     parser.add_argument("--my_img_encoder", default='x', type=str)
     # parser.add_argument("--my_img_noise_scale", default=0, type=float)
+    parser.add_argument("--my_sample_len", default=0, type=int)
+    parser.add_argument("--my_ffn_shift", default=1, type=int)
+    parser.add_argument("--my_att_shift", default=1, type=int)
+    parser.add_argument("--my_pos_emb", default=0, type=int)
+    parser.add_argument("--load_partial", default=0, type=int)
 
     parser = Trainer.add_argparse_args(parser)
     args = parser.parse_args()
@@ -108,7 +113,7 @@
     from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 
     if args.random_seed >= 0:
-        print(f"########## WARNING: GLOBAL SEED SET TO {args.random_seed} ##########\n" * 3)
+        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
         seed_everything(args.random_seed)
 
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
@@ -210,7 +215,7 @@
     )
     rank_zero_info(str(vars(args)) + "\n")
 
-    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img"]
+    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
 
     if args.lr_final == 0 or args.lr_init == 0:
         rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
@@ -276,10 +281,11 @@
 
             print(f"Trying {args.load_model}")
         load_dict = torch.load(args.load_model, map_location="cpu")
-        # load_keys = load_dict.keys()
-        # for k in model.state_dict():
-        #     if k not in load_keys:
-        #         load_dict[k] = model.state_dict()[k]
+        if args.load_partial == 1:
+            load_keys = load_dict.keys()
+            for k in model.state_dict():
+                if k not in load_keys:
+                    load_dict[k] = model.state_dict()[k]
         model.load_state_dict(load_dict)
 
     trainer = Trainer.from_argparse_args(
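For reference, an invocation that exercises the new options together (every path and size below is a placeholder, not taken from this patch):

# python train.py --data_file train.uint16 --data_type uint16 --my_sample_len 1024 \
#     --ctx_len 1023 --vocab_size 16384 --my_pos_emb 32 --load_partial 1 --load_model old.pth
#
# --my_sample_len 1024 yields 1023-token training pairs; --my_pos_emb 32 matches that length
# if the positional grid is 32*32 == ctx_len + 1; --load_partial 1 backfills any keys missing
# from old.pth (e.g. freshly added pos_emb weights) from the initialized model instead of
# failing in load_state_dict.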