@@ -35,7 +35,7 @@ class ClassEmbedder(nn.Module):
 
 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda" if torch.cuda.is_available() else "cpu"):
         super().__init__()
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@@ -52,7 +52,7 @@ class TransformerEmbedder(AbstractEncoder):
 
 class BERTTokenizer(AbstractEncoder):
     """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast  # TODO: add to reuquirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@@ -80,7 +80,7 @@ class BERTTokenizer(AbstractEncoder):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+                 device="cuda" if torch.cuda.is_available() else "cpu", use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -136,7 +136,7 @@ class SpatialRescaler(nn.Module):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda" if torch.cuda.is_available() else "cpu", max_length=77):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -231,4 +231,4 @@ class FrozenClipImageEmbedder(nn.Module):
 if __name__ == "__main__":
     from ldm.util import count_params
     model = FrozenCLIPEmbedder()
-    count_params(model, verbose=True)
+    count_params(model, verbose=True)
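# --- Minimal usage sketch (illustration only, not part of the patch above). ---
# Assumptions: the file sits at ldm/modules/encoders/modules.py as in the
# upstream repo, torch and transformers are installed, and the
# openai/clip-vit-large-patch14 weights can be downloaded. It shows the intent
# of the change: on a CPU-only machine the embedder constructors now fall back
# to "cpu" instead of the previously hard-coded device="cuda".
import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder()  # device default resolves via torch.cuda.is_available()
z = embedder.encode(["a photograph of an astronaut riding a horse"])
print(z.shape, "resolved device:", "cuda" if torch.cuda.is_available() else "cpu")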