img models (wip)

main
BlinkDL 3 years ago
parent 3a7e6a6aa3
commit 40e91dd1d7

@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
######################################################################################################## ########################################################################################################
import json, math import json, math, random
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -37,6 +37,19 @@ class MyDataset(Dataset):
print("Current vocab size =", self.vocab_size, "(make sure it's correct)") print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data) self.data_size = len(self.data)
print(f"Data has {self.data_size} tokens.") print(f"Data has {self.data_size} tokens.")
elif args.data_type == "wds_img":
def identity(x):
return x
import torchvision as vision
import webdataset as wds
import torchvision.transforms as transforms
img_transform = transforms.Compose(
[transforms.CenterCrop(256)]
)
self.data = iter(wds.WebDataset(args.data_file, resampled=True).shuffle(1000).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity).with_epoch(1000000))
print("WebDataset loaded.")
self.vocab_size = -1
self.data_size = -1
else: else:
if args.data_type == "dummy": if args.data_type == "dummy":
print("Building dummy data...") print("Building dummy data...")
@ -71,15 +84,17 @@ class MyDataset(Dataset):
return self.args.epoch_steps * self.args.micro_bsz return self.args.epoch_steps * self.args.micro_bsz
def __getitem__(self, idx): def __getitem__(self, idx):
#
# we are cheating: pick a random spot in dataset
#
args = self.args args = self.args
rank = self.global_rank rank = self.global_rank
epoch = self.real_epoch epoch = self.real_epoch
world_size = self.world_size world_size = self.world_size
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
if args.data_type == "wds_img":
dd = next(self.data) # jpg, json, txt
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
return dd[0], dd[2]
else:
ctx_len = args.ctx_len ctx_len = args.ctx_len
req_len = ctx_len + 1 req_len = ctx_len + 1
@ -91,6 +106,7 @@ class MyDataset(Dataset):
i = i + args.my_pile_shift i = i + args.my_pile_shift
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
else: else:
# cheat: pick a random spot in dataset
i = np.random.randint(0, self.data_size - req_len) i = np.random.randint(0, self.data_size - req_len)
if args.data_type == "binidx": if args.data_type == "binidx":

@ -3,7 +3,6 @@
######################################################################################################## ########################################################################################################
import os, math, gc import os, math, gc
from re import L
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F

@ -0,0 +1,195 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc
import torchvision as vision
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from pytorch_msssim import SSIM
class To2Bin(torch.autograd.Function):
    """Stochastic rounding with a pass-through gradient.

    The forward pass adds uniform noise drawn from [0, 1) to every element and
    takes the floor, so a value ``v`` in [0, 1) becomes 1 with probability
    ``v`` and 0 otherwise.  The backward pass ignores the rounding and hands
    the incoming gradient back unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``floor(x + u)`` where ``u`` is fresh U[0, 1) noise shaped like ``x``."""
        noise = torch.rand_like(x)
        return (x + noise).floor()

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` (rounding is treated as identity)."""
        passthrough = grad_output.clone()
        return passthrough
def __nop(ob):
    """Return ``ob`` unchanged; used as a do-nothing decorator when JIT is off."""
    return ob


# Plain-PyTorch defaults. They are replaced by the TorchScript equivalents
# below only when the RWKV_JIT_ON environment variable is exactly "1".
MyModule = nn.Module
MyFunction = __nop
# .get() keeps this module importable when RWKV_JIT_ON is not set at all
# (a direct os.environ[...] lookup would raise KeyError at import time);
# an unset variable is treated the same as "0", i.e. JIT disabled.
if os.environ.get("RWKV_JIT_ON", "0") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
class RWKV_IMG(pl.LightningModule):
    """Convolutional image autoencoder with a binarised bottleneck.

    The encoder applies three ``pixel_unshuffle(2)`` steps (3 -> 12 -> 48 ->
    192 channels, spatial size divided by 8), each followed by two residual
    blocks, then projects to 8 channels.  That code is squashed with a sigmoid
    and stochastically rounded by ``To2Bin``.  The decoder mirrors the encoder
    with ``pixel_shuffle(2)`` and ends in a sigmoid.  The training loss is
    ``1 - SSIM(reconstruction, input)``.

    Layer attributes are named ``<stage>b<i>`` (BatchNorm2d) and
    ``<stage>w<i>`` (3x3 Conv2d) for stages ``e0, e1, e2, d0, d1, d2`` and
    ``i`` in 0..3; these names are the state-dict keys.
    """

    def __init__(self, args):
        """Build all layers.

        Args:
            args: run configuration; this class reads ``lr_init``, ``betas``,
                ``adam_eps`` and ``devices`` from it.
        """
        super().__init__()
        self.args = args

        # Encoder stages. Creation order (b0, w0, b1, w1, ...) is kept the
        # same as the hand-written version so parameter order and the random
        # initialisation sequence are unchanged.
        self._build_stage("e0", 12)
        self._build_stage("e1", 48)
        self._build_stage("e2", 192)

        # Bottleneck projections: 192 -> 8 channels and back again.
        self.ewww = nn.Conv2d(192, 8, kernel_size=3, stride=1, padding=1)
        self.dwww = nn.Conv2d(8, 192, kernel_size=3, stride=1, padding=1)

        # Decoder stages, mirroring the encoder.
        self._build_stage("d0", 192)
        self._build_stage("d1", 48)
        self._build_stage("d2", 12)

        self.SSIM = SSIM(data_range=1, size_average=True, channel=3)

    def _build_stage(self, prefix, channels):
        """Register four BatchNorm2d/Conv2d pairs named ``{prefix}b{i}`` / ``{prefix}w{i}``."""
        for i in range(4):
            setattr(self, f"{prefix}b{i}", nn.BatchNorm2d(channels))
            setattr(
                self,
                f"{prefix}w{i}",
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            )

    @staticmethod
    def _residual_block(x, norm_in, conv_in, norm_out, conv_out):
        """Return ``x + conv_out(mish(norm_out(conv_in(mish(norm_in(x))))))``."""
        return x + conv_out(F.mish(norm_out(conv_in(F.mish(norm_in(x))))))

    def configure_optimizers(self):
        """Return the Adam optimizer for all parameters (no weight decay).

        ``DeepSpeedCPUAdam`` is used when DeepSpeed offloading is active,
        ``FusedAdam`` otherwise.
        """
        args = self.args
        optim_groups = [
            {"params": list(self.parameters()), "weight_decay": 0.0},
        ]
        shared = dict(
            lr=args.lr_init,
            betas=args.betas,
            eps=args.adam_eps,
            bias_correction=True,
            weight_decay=0,
            amsgrad=False,
        )
        if self.deepspeed_offload:
            # Note the differing keyword spelling between the two optimizers.
            return DeepSpeedCPUAdam(optim_groups, adamw_mode=False, **shared)
        return FusedAdam(optim_groups, adam_w_mode=False, **shared)
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        """True when the trainer uses DeepSpeed with optimizer or parameter offload."""
        strategy = self.trainer.strategy
        if not isinstance(strategy, DeepSpeedStrategy):
            return False
        zero_cfg = strategy.config["zero_optimization"]
        # bool() so the property really returns a bool, as annotated, instead
        # of the raw offload config dict (or None).
        return bool(zero_cfg.get("offload_optimizer") or zero_cfg.get("offload_param"))

    def forward(self, img):
        """Encode ``img`` to a binary code and decode it back to an image.

        Args:
            img: image batch with 3 channels (the first BatchNorm expects
                3 * 4 = 12 channels after the first unshuffle); height and
                width must be divisible by 8 for the three unshuffle steps.

        Returns:
            Reconstruction with the same shape as ``img``, values in [0, 1]
            (output of a sigmoid).
        """
        block = self._residual_block

        # ---- encoder ----
        x = F.pixel_unshuffle(img, 2)
        x = block(x, self.e0b0, self.e0w0, self.e0b1, self.e0w1)
        x = block(x, self.e0b2, self.e0w2, self.e0b3, self.e0w3)

        x = F.pixel_unshuffle(x, 2)
        x = block(x, self.e1b0, self.e1w0, self.e1b1, self.e1w1)
        x = block(x, self.e1b2, self.e1w2, self.e1b3, self.e1w3)

        x = F.pixel_unshuffle(x, 2)
        x = block(x, self.e2b0, self.e2w0, self.e2b1, self.e2w1)
        x = block(x, self.e2b2, self.e2w2, self.e2b3, self.e2w3)

        # ---- bottleneck: sigmoid, then stochastic rounding ----
        x = self.ewww(x)
        x = To2Bin.apply(torch.sigmoid(x))
        # print(x.shape, x)

        # ---- decoder ----
        x = self.dwww(x)
        x = block(x, self.d0b0, self.d0w0, self.d0b1, self.d0w1)
        x = block(x, self.d0b2, self.d0w2, self.d0b3, self.d0w3)

        x = F.pixel_shuffle(x, 2)
        x = block(x, self.d1b0, self.d1w0, self.d1b1, self.d1w1)
        x = block(x, self.d1b2, self.d1w2, self.d1b3, self.d1w3)

        x = F.pixel_shuffle(x, 2)
        x = block(x, self.d2b0, self.d2w0, self.d2b1, self.d2w1)
        x = block(x, self.d2b2, self.d2w2, self.d2b3, self.d2w3)

        x = F.pixel_shuffle(x, 2)
        return torch.sigmoid(x)

    def training_step(self, batch, batch_idx):
        """Return the loss ``1 - SSIM`` for one batch and periodically dump previews.

        Args:
            batch: pair ``(img, txt)``; only ``img`` is used.
            batch_idx: index of the batch (unused).
        """
        args = self.args
        img, _ = batch  # second element is not needed for reconstruction
        out = self(img)

        if self.trainer.is_global_zero:
            step = self.trainer.global_step
            if (step + 1) % (100 * int(args.devices)) == 0:
                # save_image does not create missing directories, so make
                # sure the target exists before the first preview is written.
                os.makedirs("test/image_model", exist_ok=True)
                vision.utils.save_image(img[:4], f"test/image_model/{step}-src.jpg")
                vision.utils.save_image(out[:4], f"test/image_model/{step}-out.jpg")

        return 1 - self.SSIM(out.float(), img.float())

    def training_step_end(self, batch_parts):
        """Gather the step output across processes and stash it on the trainer (rank zero only)."""
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = gathered

    def generate_init_weight(self):
        """Return the state dict with tensors on CPU, cast per ``RWKV_FLOAT_MODE``.

        ``"fp16"`` casts to half precision, ``"bf16"`` to bfloat16; any other
        value leaves the dtype untouched.

        Raises:
            KeyError: if the ``RWKV_FLOAT_MODE`` environment variable is unset.
        """
        print(
            """
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        float_mode = os.environ["RWKV_FLOAT_MODE"]

        weights = {}
        # Build the state dict once instead of once per key.
        for name, tensor in self.state_dict().items():
            tensor = tensor.cpu()
            if float_mode == "fp16":
                tensor = tensor.half()
            elif float_mode == "bf16":
                tensor = tensor.bfloat16()
            weights[name] = tensor

        gc.collect()
        torch.cuda.empty_cache()
        return weights

@ -158,7 +158,7 @@ if __name__ == "__main__":
if args.my_pile_stage == 2: if args.my_pile_stage == 2:
args.warmup_steps = 10 args.warmup_steps = 10
else: else:
args.warmup_steps = 50 args.warmup_steps = 30
args.epoch_begin = max_p + 1 args.epoch_begin = max_p + 1
samples_per_epoch = args.epoch_steps * args.real_bsz samples_per_epoch = args.epoch_steps * args.real_bsz
@ -188,7 +188,7 @@ if __name__ == "__main__":
) )
rank_zero_info(str(vars(args)) + "\n") rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy"] assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img"]
if args.lr_final == 0 or args.lr_init == 0: if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
@ -223,11 +223,15 @@ if __name__ == "__main__":
from src.trainer import train_callback, generate_init_weight from src.trainer import train_callback, generate_init_weight
from src.dataset import MyDataset from src.dataset import MyDataset
from src.model import RWKV
train_data = MyDataset(args) train_data = MyDataset(args)
args.vocab_size = train_data.vocab_size args.vocab_size = train_data.vocab_size
if args.data_type == 'wds_img':
from src.model_img import RWKV_IMG
model = RWKV_IMG(args)
else:
from src.model import RWKV
model = RWKV(args) model = RWKV(args)
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights? if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
@ -250,6 +254,10 @@ if __name__ == "__main__":
print(f"Trying {args.load_model}") print(f"Trying {args.load_model}")
load_dict = torch.load(args.load_model, map_location="cpu") load_dict = torch.load(args.load_model, map_location="cpu")
# load_keys = load_dict.keys()
# for k in model.state_dict():
# if k not in load_keys:
# load_dict[k] = model.state_dict()[k]
model.load_state_dict(load_dict) model.load_state_dict(load_dict)
trainer = Trainer.from_argparse_args( trainer = Trainer.from_argparse_args(

Loading…
Cancel
Save