+AE model

main
BlinkDL 3 years ago
parent 9e325b88cb
commit 77bcaa8247

@ -0,0 +1,165 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import torch, types, os
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.nn import functional as F
import torchvision as vision
import torchvision.transforms as transforms
np.set_printoptions(precision=4, suppress=True, linewidth=200)
print(f'loading...')
########################################################################################################
model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201'
input_img = 'test/img_ae_test/test0.png'
########################################################################################################
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone() # pass-through
class R_ENCODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = self.COUT(x + xx)
return torch.sigmoid(x)
class R_DECODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
print(f'building model...')
args = types.SimpleNamespace()
args.my_img_bit = 13
encoder = R_ENCODER(args).eval().cuda()
decoder = R_DECODER(args).eval().cuda()
zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long()
encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth'))
decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth'))
########################################################################################################
print(f'test image...')
img_transform = transforms.Compose([
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Resize((224, 224))
])
with torch.no_grad():
img = img_transform(Image.open(input_img)).unsqueeze(0).cuda()
z = encoder(img)
z = ToBinary.apply(z)
zz = torch.sum(z.squeeze().long() * zpow, dim=0)
print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')
out = decoder(z)
vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg")

@ -37,6 +37,12 @@ class MyDataset(Dataset):
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
print(f"Data has {self.data_size} tokens.")
elif args.data_type == "uint16":
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = self.data.shape[0]
print(f"Data has {self.data_size} samples.")
elif args.data_type == "wds_img":
self.vocab_size = -1
self.data_size = -1
@ -94,7 +100,7 @@ class MyDataset(Dataset):
transforms.CenterCrop(512),
transforms.Resize((args.my_img_size))
])
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
self.data_raw = wds.WebDataset(args.data_file, resampled=True, handler=wds.handlers.warn_and_continue).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
for pp in self.data_raw.pipeline:
if 'Resampled' in str(pp):
pp.deterministic = True
@ -109,6 +115,12 @@ class MyDataset(Dataset):
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
return dd[0], dd[2]
else:
if args.data_type == "uint16":
i = np.random.randint(0, self.data_size-1)
dix = self.data[i]
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
else:
ctx_len = args.ctx_len
req_len = ctx_len + 1

@ -44,6 +44,8 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = -torch.exp(w[x])
if 'pos_emb_x' in x:
self.w.pos_emb = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
@ -150,6 +152,11 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
with torch.no_grad():
w = self.w
x = w.emb.weight[ctx[-1]]
try:
pos_emb = w.pos_emb[len(ctx)-1]
x = x + pos_emb
except:
pass
for i in range(self.n_layer):
if i == 0:

@ -92,6 +92,11 @@ if __name__ == "__main__":
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
@ -108,7 +113,7 @@ if __name__ == "__main__":
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED SET TO {args.random_seed} ##########\n" * 3)
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
@ -210,7 +215,7 @@ if __name__ == "__main__":
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img"]
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
@ -276,10 +281,11 @@ if __name__ == "__main__":
print(f"Trying {args.load_model}")
load_dict = torch.load(args.load_model, map_location="cpu")
# load_keys = load_dict.keys()
# for k in model.state_dict():
# if k not in load_keys:
# load_dict[k] = model.state_dict()[k]
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in model.state_dict():
if k not in load_keys:
load_dict[k] = model.state_dict()[k]
model.load_state_dict(load_dict)
trainer = Trainer.from_argparse_args(

Loading…
Cancel
Save