img models (wip)
parent
3a7e6a6aa3
commit
40e91dd1d7
@@ -0,0 +1,195 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import os, math, gc
|
||||
import torchvision as vision
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from pytorch_lightning.strategies import DeepSpeedStrategy
|
||||
import deepspeed
|
||||
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
||||
from pytorch_msssim import SSIM
|
||||
|
||||
class To2Bin(torch.autograd.Function):
    """Stochastic rounding with a pass-through (straight-through) gradient.

    Forward adds uniform noise drawn from [0, 1) to the input and floors the
    sum, so an input value ``p`` in [0, 1] becomes 1 with probability ``p``
    and 0 otherwise. Backward ignores the rounding and hands the incoming
    gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``floor(x + u)`` where ``u`` is fresh U[0, 1) noise shaped like ``x``."""
        noise = torch.empty_like(x).uniform_(0, 1)
        return torch.floor(x + noise)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the upstream gradient through as a copy (rounding treated as identity)."""
        return grad_output.clone()
|
||||
|
||||
def __nop(ob):
    """Identity decorator: return ``ob`` unchanged.

    Stands in for ``torch.jit.script_method`` when JIT scripting is disabled.
    """
    return ob


# Aliases that select between plain eager-mode PyTorch and TorchScript.
# Default: a regular nn.Module and a no-op method decorator.
MyModule = nn.Module
MyFunction = __nop
# Opt in to TorchScript only when RWKV_JIT_ON is exactly "1". Using .get()
# keeps the module importable when the variable is not set at all (a plain
# os.environ[...] lookup would raise KeyError at import time).
if os.environ.get("RWKV_JIT_ON") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
|
||||
|
||||
class RWKV_IMG(pl.LightningModule):
    """Convolutional image autoencoder with a stochastically binarized bottleneck.

    Encoder: three stages, each a ``pixel_unshuffle`` by 2 followed by two
    residual blocks (BatchNorm -> mish -> 3x3 conv, twice). For the 3-channel
    images the SSIM loss is configured for, channel counts go 3 -> 12 -> 48
    -> 192 while height and width shrink by a factor of 8 overall.

    Bottleneck: a 3x3 conv to 8 channels, sigmoid, then ``To2Bin`` which
    samples each value to 0/1 with a straight-through gradient.

    Decoder: mirrors the encoder with ``pixel_shuffle`` (192 -> 48 -> 12 -> 3)
    and ends in a sigmoid.

    Training loss is ``1 - SSIM(output, input)``.
    """

    def __init__(self, args):
        """Build all layers.

        Args:
            args: Run configuration object. This class reads ``lr_init``,
                ``betas``, ``adam_eps`` and ``devices`` from it.
        """
        super().__init__()
        self.args = args

        bn = nn.BatchNorm2d
        conv = self._conv3x3

        # NOTE: attribute names and creation order are kept exactly as before
        # so state-dict keys and parameter-init RNG order do not change.

        # Encoder stage 0: 12 channels.
        self.e0b0 = bn(12)
        self.e0w0 = conv(12, 12)
        self.e0b1 = bn(12)
        self.e0w1 = conv(12, 12)
        self.e0b2 = bn(12)
        self.e0w2 = conv(12, 12)
        self.e0b3 = bn(12)
        self.e0w3 = conv(12, 12)

        # Encoder stage 1: 48 channels.
        self.e1b0 = bn(48)
        self.e1w0 = conv(48, 48)
        self.e1b1 = bn(48)
        self.e1w1 = conv(48, 48)
        self.e1b2 = bn(48)
        self.e1w2 = conv(48, 48)
        self.e1b3 = bn(48)
        self.e1w3 = conv(48, 48)

        # Encoder stage 2: 192 channels.
        self.e2b0 = bn(192)
        self.e2w0 = conv(192, 192)
        self.e2b1 = bn(192)
        self.e2w1 = conv(192, 192)
        self.e2b2 = bn(192)
        self.e2w2 = conv(192, 192)
        self.e2b3 = bn(192)
        self.e2w3 = conv(192, 192)

        # Projection into / out of the 8-channel binary code.
        self.ewww = conv(192, 8)
        self.dwww = conv(8, 192)

        # Decoder stage 0: 192 channels.
        self.d0b0 = bn(192)
        self.d0w0 = conv(192, 192)
        self.d0b1 = bn(192)
        self.d0w1 = conv(192, 192)
        self.d0b2 = bn(192)
        self.d0w2 = conv(192, 192)
        self.d0b3 = bn(192)
        self.d0w3 = conv(192, 192)

        # Decoder stage 1: 48 channels.
        self.d1b0 = bn(48)
        self.d1w0 = conv(48, 48)
        self.d1b1 = bn(48)
        self.d1w1 = conv(48, 48)
        self.d1b2 = bn(48)
        self.d1w2 = conv(48, 48)
        self.d1b3 = bn(48)
        self.d1w3 = conv(48, 48)

        # Decoder stage 2: 12 channels.
        self.d2b0 = bn(12)
        self.d2w0 = conv(12, 12)
        self.d2b1 = bn(12)
        self.d2w1 = conv(12, 12)
        self.d2b2 = bn(12)
        self.d2w2 = conv(12, 12)
        self.d2b3 = bn(12)
        self.d2w3 = conv(12, 12)

        self.SSIM = SSIM(data_range=1, size_average=True, channel=3)

    @staticmethod
    def _conv3x3(in_channels, out_channels):
        """Return a 3x3, stride-1, padding-1 convolution (spatial size preserved)."""
        return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _residual(x, bn_a, conv_a, bn_b, conv_b):
        """Return ``x + conv_b(mish(bn_b(conv_a(mish(bn_a(x))))))``."""
        return x + conv_b(F.mish(bn_b(conv_a(F.mish(bn_a(x))))))

    def configure_optimizers(self):
        """Return the Adam optimizer for all parameters (no weight decay).

        Uses ``DeepSpeedCPUAdam`` when DeepSpeed offloading is enabled and
        ``FusedAdam`` otherwise. Both run in plain Adam (not AdamW) mode.
        """
        args = self.args
        optim_groups = [
            {"params": list(self.parameters()), "weight_decay": 0.0},
        ]
        shared = dict(
            lr=args.lr_init,
            betas=args.betas,
            eps=args.adam_eps,
            bias_correction=True,
            weight_decay=0,
            amsgrad=False,
        )
        if self.deepspeed_offload:
            # DeepSpeedCPUAdam spells the flag "adamw_mode" ...
            return DeepSpeedCPUAdam(optim_groups, adamw_mode=False, **shared)
        # ... while FusedAdam spells it "adam_w_mode".
        return FusedAdam(optim_groups, adam_w_mode=False, **shared)
        # Disabled alternative kept from the original:
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        """Whether the trainer's DeepSpeed config offloads optimizer state or params.

        Returns ``False`` when the active strategy is not a ``DeepSpeedStrategy``.
        """
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            zero_cfg = strategy.config["zero_optimization"]
            # The config values are dicts/None; coerce so the declared
            # ``bool`` return type is actually honoured.
            return bool(zero_cfg.get("offload_optimizer") or zero_cfg.get("offload_param"))
        return False

    def forward(self, img):
        """Encode ``img`` to a binary code and decode it back.

        Args:
            img: Image batch. The layer sizes imply 3 channels; height and
                width must be divisible by 8 for the three ``pixel_unshuffle``
                steps.

        Returns:
            Reconstruction passed through a sigmoid (values in (0, 1)).
        """
        x = img

        # ---- encoder ----
        x = F.pixel_unshuffle(x, 2)
        x = self._residual(x, self.e0b0, self.e0w0, self.e0b1, self.e0w1)
        x = self._residual(x, self.e0b2, self.e0w2, self.e0b3, self.e0w3)

        x = F.pixel_unshuffle(x, 2)
        x = self._residual(x, self.e1b0, self.e1w0, self.e1b1, self.e1w1)
        x = self._residual(x, self.e1b2, self.e1w2, self.e1b3, self.e1w3)

        x = F.pixel_unshuffle(x, 2)
        x = self._residual(x, self.e2b0, self.e2w0, self.e2b1, self.e2w1)
        x = self._residual(x, self.e2b2, self.e2w2, self.e2b3, self.e2w3)

        x = self.ewww(x)

        # ---- bottleneck: sample each sigmoid output to 0/1 ----
        x = To2Bin.apply(torch.sigmoid(x))

        # ---- decoder ----
        x = self.dwww(x)

        x = self._residual(x, self.d0b0, self.d0w0, self.d0b1, self.d0w1)
        x = self._residual(x, self.d0b2, self.d0w2, self.d0b3, self.d0w3)
        x = F.pixel_shuffle(x, 2)

        x = self._residual(x, self.d1b0, self.d1w0, self.d1b1, self.d1w1)
        x = self._residual(x, self.d1b2, self.d1w2, self.d1b3, self.d1w3)
        x = F.pixel_shuffle(x, 2)

        x = self._residual(x, self.d2b0, self.d2w0, self.d2b1, self.d2w1)
        x = self._residual(x, self.d2b2, self.d2w2, self.d2b3, self.d2w3)
        x = F.pixel_shuffle(x, 2)

        return torch.sigmoid(x)

    def training_step(self, batch, batch_idx):
        """Run one reconstruction step and return the loss ``1 - SSIM``.

        Args:
            batch: Pair ``(img, <second item>)``; only ``img`` is used.
            batch_idx: Unused.

        On the global-zero rank, whenever ``global_step + 1`` is a multiple of
        ``100 * args.devices``, the first four inputs and reconstructions are
        written to ``test/image_model/`` as JPEG files.
        """
        args = self.args
        img, _ = batch
        out = self(img)
        if self.trainer.is_global_zero:
            step = self.trainer.global_step
            if (step + 1) % (100 * int(args.devices)) == 0:
                # save_image does not create missing directories.
                os.makedirs("test/image_model", exist_ok=True)
                vision.utils.save_image(img[:4], f"test/image_model/{step}-src.jpg")
                vision.utils.save_image(out[:4], f"test/image_model/{step}-out.jpg")

        return 1 - self.SSIM(out.float(), img.float())

    def training_step_end(self, batch_parts):
        """Gather the step output across processes and stash it on the trainer.

        The gathered value is stored as ``trainer.my_loss_all`` on the
        global-zero rank only. Nothing is returned.
        """
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = gathered

    def generate_init_weight(self):
        """Return the current state dict as CPU tensors in the configured float mode.

        Every entry of ``state_dict()`` is moved to CPU and, depending on the
        ``RWKV_FLOAT_MODE`` environment variable, cast with ``.half()``
        (``"fp16"``) or ``.bfloat16()`` (``"bf16"``). Any other value, or an
        unset variable, leaves the dtype untouched.

        Returns:
            dict mapping state-dict key to the converted tensor.
        """
        print(
            """
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        # Read once; .get() avoids a KeyError when the variable is not set.
        float_mode = os.environ.get("RWKV_FLOAT_MODE")

        m = {}
        # Build the state dict a single time instead of once per key.
        for name, tensor in self.state_dict().items():
            tensor = tensor.cpu()
            # NOTE(review): the cast is applied to every entry, as in the
            # original, which would include any non-float buffers.
            if float_mode == "fp16":
                tensor = tensor.half()
            elif float_mode == "bf16":
                tensor = tensor.bfloat16()
            m[name] = tensor

        gc.collect()
        torch.cuda.empty_cache()
        return m
|
||||
Loading…
Reference in New Issue