Add files via upload

Code changes from https://github.com/CompVis/latent-diffusion/pull/123
main
ModeratePrawn 3 years ago committed by GitHub
parent 69ae4b35e0
commit feae820740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,7 +18,7 @@ class DDIMSampler(object):
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != torch.device("cuda") and torch.cuda.is_available():
attr = attr.to(torch.device("cuda")) attr = attr.to(torch.device("cuda"))
setattr(self, name, attr) setattr(self, name, attr)

@ -17,7 +17,7 @@ class PLMSSampler(object):
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != torch.device("cuda") and torch.cuda.is_available():
attr = attr.to(torch.device("cuda")) attr = attr.to(torch.device("cuda"))
setattr(self, name, attr) setattr(self, name, attr)

@ -35,7 +35,7 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder): class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers""" """Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda" if torch.cuda.is_available() else "cpu"):
super().__init__() super().__init__()
self.device = device self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@ -52,7 +52,7 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder): class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77): def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", vq_interface=True, max_length=77):
super().__init__() super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements from transformers import BertTokenizerFast # TODO: add to reuquirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@ -80,7 +80,7 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder): class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers""" """Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0): device="cuda" if torch.cuda.is_available() else "cpu", use_tokenizer=True, embedding_dropout=0.0):
super().__init__() super().__init__()
self.use_tknz_fn = use_tokenizer self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn: if self.use_tknz_fn:
@ -136,7 +136,7 @@ class SpatialRescaler(nn.Module):
class FrozenCLIPEmbedder(AbstractEncoder): class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)""" """Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): def __init__(self, version="openai/clip-vit-large-patch14", device="cuda" if torch.cuda.is_available() else "cpu", max_length=77):
super().__init__() super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version) self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version)

@ -44,6 +44,7 @@ def load_model_from_config(config, ckpt):
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
if torch.cuda.is_available():
model.cuda() model.cuda()
model.eval() model.eval()
return {"model": model}, global_step return {"model": model}, global_step
@ -117,6 +118,7 @@ def get_cond(mode, selected_path):
c = rearrange(c, '1 c h w -> 1 h w c') c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1. c = 2. * c - 1.
if torch.cuda.is_available():
c = c.to(torch.device("cuda")) c = c.to(torch.device("cuda"))
example["LR_image"] = c example["LR_image"] = c
example["image"] = c_up example["image"] = c_up

@ -53,6 +53,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:") print("unexpected keys:")
print(u) print(u)
if torch.cuda.is_available():
model.cuda() model.cuda()
model.eval() model.eval()
return model return model
@ -358,7 +359,10 @@ if __name__ == "__main__":
uc = None uc = None
if searcher is not None: if searcher is not None:
nn_dict = searcher(c, opt.knn) nn_dict = searcher(c, opt.knn)
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) nn_embeddings = torch.from_numpy(nn_dict['nn_embeddings'])
if torch.cuda.is_available():
nn_embeddings = nn_embeddings.cuda()
c = torch.cat([c, nn_embeddings], dim=1)
if opt.scale != 1.0: if opt.scale != 1.0:
uc = torch.zeros_like(c) uc = torch.zeros_like(c)
if isinstance(prompts, tuple): if isinstance(prompts, tuple):

@ -220,6 +220,7 @@ def get_parser():
def load_model_from_config(config, sd): def load_model_from_config(config, sd):
model = instantiate_from_config(config) model = instantiate_from_config(config)
model.load_state_dict(sd,strict=False) model.load_state_dict(sd,strict=False)
if torch.cuda.is_available():
model.cuda() model.cuda()
model.eval() model.eval()
return model return model

@ -60,6 +60,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:") print("unexpected keys:")
print(u) print(u)
if torch.cuda.is_available():
model.cuda() model.cuda()
model.eval() model.eval()
return model return model

Loading…
Cancel
Save