stable diffusion
# Stable Diffusion v1 Model Card

This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
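The weights are meant to be used with the sampling scripts of that repository. As a rough sketch of how the scripts further down in this diff load them (the paths are the repository defaults and may differ in your setup):

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# assumed default locations, as used by the sampling scripts in the repository
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
ckpt = torch.load("models/ldm/stable-diffusion-v1/model.ckpt", map_location="cpu")

model = instantiate_from_config(config.model)
model.load_state_dict(ckpt["state_dict"], strict=False)
model.eval()  # sampling only; see the DDIM/PLMS samplers later in this diff
```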
## Model Details

- **Developed by:** Robin Rombach, Patrick Esser
- **Model type:** Diffusion-based text-to-image generation model
- **Language(s):** English
- **License:** [Proprietary](LICENSE)
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
- **Cite as:**
```bibtex
@InProceedings{Rombach_2022_CVPR,
    author    = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
    title     = {High-Resolution Image Synthesis With Latent Diffusion Models},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {10684-10695}
}
```
# Uses

## Direct Use

The model is intended for research purposes only. Possible research areas and tasks include:

- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
- Generation of artworks and use in design and other artistic processes.
- Applications in educational or creative tools.
- Research on generative models.

Excluded uses are described below.
### Misuse, Malicious Use, and Out-of-Scope Use

_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1._

The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.

#### Out-of-Scope Use

The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.

#### Misuse and Malicious Use

Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:

- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
- Impersonating individuals without their consent.
- Sexual content without consent of the people who might see it.
- Mis- and disinformation.
- Representations of egregious violence and gore.
- Sharing of copyrighted or licensed material in violation of its terms of use.
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
## Limitations and Bias

### Limitations

- The model does not achieve perfect photorealism.
- The model cannot render legible text.
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”.
- Faces and people in general may not be generated properly.
- The model was trained mainly with English captions and will not work as well in other languages.
- The autoencoding part of the model is lossy.
- The model was trained on a large-scale dataset, [LAION-5B](https://laion.ai/blog/laion-5b/), which contains adult material and is not fit for product use without additional safety mechanisms and considerations.
### Bias

While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), which consists of images that are primarily limited to English descriptions. Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. This affects the overall output of the model, as white and western cultures are often set as the default. Further, the ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
## Training

**Training Data**

The model developers used the following dataset for training the model:

- LAION-2B (en) and subsets thereof (see next section)

**Training Procedure**

Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,

- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4 (see the sketch below).
- Text prompts are encoded through a ViT-L/14 text encoder.
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
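The shape bookkeeping and the noise-prediction objective can be summarized with a minimal PyTorch sketch (illustrative only, not the actual training loop; `unet` and `encoder` stand in for the real modules):

```python
import torch

f = 8                                  # relative downsampling factor of the autoencoder
H, W = 512, 512                        # training resolution of sd-v1-2 / sd-v1-3
x = torch.randn(1, 3, H, W)            # a batch with one RGB image
# encoder(x) would produce a latent of shape (1, 4, H // f, W // f), i.e. (1, 4, 64, 64)

z = torch.randn(1, 4, H // f, W // f)  # stand-in for the encoded latent
noise = torch.randn_like(z)            # the noise added during the forward diffusion
t = torch.randint(0, 1000, (1,))       # sampled diffusion timestep

# The UNet predicts the added noise from the noised latent, the timestep and the
# text embedding; the loss is a simple reconstruction (MSE) objective:
#   pred = unet(z_noisy, t, context=text_embedding)
#   loss = torch.nn.functional.mse_loss(pred, noise)
```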
We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`, which were trained as follows:

- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en). 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`. 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en, filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598) (sketched below).
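At sampling time, classifier-free guidance blends the conditional and unconditional noise predictions; a minimal sketch of the combination step (mirroring the logic in the PLMS sampler later in this diff, with random tensors standing in for the two UNet outputs):

```python
import torch

def guided_eps(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, scale: float) -> torch.Tensor:
    # scale = 1.0 recovers the purely conditional prediction;
    # larger scales push samples closer to the text prompt.
    return eps_uncond + scale * (eps_cond - eps_uncond)

eps_uncond = torch.randn(1, 4, 64, 64)  # stand-in for the unconditional prediction
eps_cond = torch.randn(1, 4, 64, 64)    # stand-in for the conditional prediction
eps = guided_eps(eps_uncond, eps_cond, scale=7.5)
```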
- **Hardware:** 32 x 8 x A100 GPUs
- **Optimizer:** AdamW
- **Gradient Accumulations:** 2
- **Batch:** 32 x 8 x 2 x 4 = 2048 (nodes x GPUs per node x gradient accumulation steps x per-GPU batch size)
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
## Evaluation Results

Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling steps show the relative improvements of the checkpoints:



Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set at 512x512 resolution. Not optimized for FID scores.
## Environmental Impact

**Stable Diffusion v1** **Estimated Emissions**

Based on the information below, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.

- **Hardware Type:** A100 PCIe 40GB
- **Hours used:** 150000
- **Cloud Provider:** AWS
- **Compute Region:** US-east
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
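As a rough consistency check of the reported figure (the per-GPU power draw and the grid carbon intensity below are assumptions, not reported values):

```python
gpu_hours = 150_000    # reported A100 PCIe 40GB hours
power_kw = 0.25        # assumed average draw per A100 PCIe (~250 W)
kg_co2_per_kwh = 0.3   # assumed carbon intensity of the US-east grid

emissions_kg = gpu_hours * power_kw * kg_co2_per_kwh
print(f"{emissions_kg:.0f} kg CO2 eq.")  # 11250 kg CO2 eq.
```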
## Citation

```bibtex
@InProceedings{Rombach_2022_CVPR,
    author    = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
    title     = {High-Resolution Image Synthesis With Latent Diffusion Models},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {10684-10695}
}
```

*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
@ -0,0 +1,68 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 0.0001
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0195
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: class_label
|
||||||
|
image_size: 64
|
||||||
|
channels: 3
|
||||||
|
cond_stage_trainable: true
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 64
|
||||||
|
in_channels: 3
|
||||||
|
out_channels: 3
|
||||||
|
model_channels: 192
|
||||||
|
attention_resolutions:
|
||||||
|
- 8
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 5
|
||||||
|
num_heads: 1
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 512
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.VQModelInterface
|
||||||
|
params:
|
||||||
|
embed_dim: 3
|
||||||
|
n_embed: 8192
|
||||||
|
ddconfig:
|
||||||
|
double_z: false
|
||||||
|
z_channels: 3
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.ClassEmbedder
|
||||||
|
params:
|
||||||
|
n_classes: 1001
|
||||||
|
embed_dim: 512
|
||||||
|
key: class_label
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.012
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: caption
|
||||||
|
image_size: 32
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: true
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
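# scale_factor rescales the KL autoencoder latents to roughly unit variance before they enter the diffusion model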
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions:
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1280
|
||||||
|
use_checkpoint: true
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.BERTEmbedder
|
||||||
|
params:
|
||||||
|
n_embed: 1280
|
||||||
|
n_layer: 32
|
||||||
@ -0,0 +1,68 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 0.0001
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.015
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: jpg
|
||||||
|
cond_stage_key: nix
|
||||||
|
image_size: 48
|
||||||
|
channels: 16
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_by_std: false
|
||||||
|
scale_factor: 0.22765929
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 48
|
||||||
|
in_channels: 16
|
||||||
|
out_channels: 16
|
||||||
|
model_channels: 448
|
||||||
|
attention_resolutions:
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
use_scale_shift_norm: false
|
||||||
|
resblock_updown: false
|
||||||
|
num_head_channels: 32
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: true
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
monitor: val/rec_loss
|
||||||
|
embed_dim: 16
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions:
|
||||||
|
- 16
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
cond_stage_config:
|
||||||
|
target: torch.nn.Identity
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
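# pretrained CLIP ViT-L/14 text encoder, kept frozen during training (cond_stage_trainable: false above)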
|
||||||
@ -0,0 +1,236 @@
|
|||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||||
|
|
||||||
|
|
||||||
|
class PLMSSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device("cuda"):
|
||||||
|
attr = attr.to(torch.device("cuda"))
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
if ddim_eta != 0:
|
||||||
|
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,verbose=verbose)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||||
|
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
|
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
|
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sampling(self, cond, shape,
|
||||||
|
x_T=None, ddim_use_original_steps=False,
|
||||||
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
if x_T is None:
|
||||||
|
img = torch.randn(shape, device=device)
|
||||||
|
else:
|
||||||
|
img = x_T
|
||||||
|
|
||||||
|
if timesteps is None:
|
||||||
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
elif timesteps is not None and not ddim_use_original_steps:
|
||||||
|
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||||
|
timesteps = self.ddim_timesteps[:subset_end]
|
||||||
|
|
||||||
|
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||||
|
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||||
|
old_eps = []
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
old_eps=old_eps, t_next=ts_next)
|
||||||
|
img, pred_x0, e_t = outs
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
return img, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
def get_model_output(x, t):
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
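# classifier-free guidance: the unconditional and conditional branches share one forward pass;
# the two noise predictions are split and blended with the guidance scale below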
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
e_t = get_model_output(x, t)
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = get_model_output(x_prev, t_next)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
elif len(old_eps) >= 3:
|
||||||
|
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
return x_prev, pred_x0, e_t
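# Illustrative usage (not part of the original file), assuming `model` is a loaded
# LatentDiffusion instance and `c` / `uc` come from model.get_learned_conditioning:
#
#   sampler = PLMSSampler(model)
#   samples, _ = sampler.sample(S=50, batch_size=4, shape=(4, 64, 64),
#                               conditioning=c, unconditional_guidance_scale=7.5,
#                               unconditional_conditioning=uc, eta=0.)
#
# PLMS requires eta=0 (see make_schedule above); the returned samples are latents
# and still need model.decode_first_stage to become images.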
|
||||||
@ -0,0 +1,293 @@
|
|||||||
|
"""make variations of input image"""
|
||||||
|
|
||||||
|
import argparse, os, sys, glob
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
from itertools import islice
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
from torch import autocast
|
||||||
|
from contextlib import nullcontext
|
||||||
|
import time
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(it, size):
|
||||||
|
it = iter(it)
|
||||||
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
|
print(f"Loading model from {ckpt}")
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
if len(m) > 0 and verbose:
|
||||||
|
print("missing keys:")
|
||||||
|
print(m)
|
||||||
|
if len(u) > 0 and verbose:
|
||||||
|
print("unexpected keys:")
|
||||||
|
print(u)
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_img(path):
|
||||||
|
image = Image.open(path).convert("RGB")
|
||||||
|
w, h = image.size
|
||||||
|
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||||
|
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||||
|
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
return 2.*image - 1.
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
default="a painting of a virus monster playing guitar",
|
||||||
|
help="the prompt to render"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--init-img",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="path to the input image"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--outdir",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="dir to write results to",
|
||||||
|
default="outputs/img2img-samples"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_grid",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_save",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save indiviual samples. For speed measurements.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_steps",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="number of ddim sampling steps",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--plms",
|
||||||
|
action='store_true',
|
||||||
|
help="use plms sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fixed_code",
|
||||||
|
action='store_true',
|
||||||
|
help="if enabled, uses the same starting code across all samples ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_eta",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_iter",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="sample this often",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--C",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="latent channels",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--f",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="downsampling factor, most often 8 or 16",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_samples",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_rows",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="rows in the grid (default: n_samples)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scale",
|
||||||
|
type=float,
|
||||||
|
default=5.0,
|
||||||
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--strength",
|
||||||
|
type=float,
|
||||||
|
default=0.75,
|
||||||
|
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from-file",
|
||||||
|
type=str,
|
||||||
|
help="if specified, load prompts from this file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="configs/stable-diffusion/v1-inference.yaml",
|
||||||
|
help="path to config which constructs model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt",
|
||||||
|
type=str,
|
||||||
|
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||||
|
help="path to checkpoint of model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="the seed (for reproducible sampling)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--precision",
|
||||||
|
type=str,
|
||||||
|
help="evaluate at this precision",
|
||||||
|
choices=["full", "autocast"],
|
||||||
|
default="autocast"
|
||||||
|
)
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
seed_everything(opt.seed)
|
||||||
|
|
||||||
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
|
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||||
|
|
||||||
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
if opt.plms:
|
||||||
|
raise NotImplementedError("PLMS sampler not (yet) supported")
|
||||||
|
sampler = PLMSSampler(model)
|
||||||
|
else:
|
||||||
|
sampler = DDIMSampler(model)
|
||||||
|
|
||||||
|
os.makedirs(opt.outdir, exist_ok=True)
|
||||||
|
outpath = opt.outdir
|
||||||
|
|
||||||
|
batch_size = opt.n_samples
|
||||||
|
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||||
|
if not opt.from_file:
|
||||||
|
prompt = opt.prompt
|
||||||
|
assert prompt is not None
|
||||||
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"reading prompts from {opt.from_file}")
|
||||||
|
with open(opt.from_file, "r") as f:
|
||||||
|
data = f.read().splitlines()
|
||||||
|
data = list(chunk(data, batch_size))
|
||||||
|
|
||||||
|
sample_path = os.path.join(outpath, "samples")
|
||||||
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
|
base_count = len(os.listdir(sample_path))
|
||||||
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
assert os.path.isfile(opt.init_img)
|
||||||
|
init_image = load_img(opt.init_img).to(device)
|
||||||
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
|
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
|
||||||
|
|
||||||
|
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
t_enc = int(opt.strength * opt.ddim_steps)
|
||||||
|
print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
|
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||||
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with model.ema_scope():
|
||||||
|
tic = time.time()
|
||||||
|
all_samples = list()
|
||||||
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
uc = None
|
||||||
|
if opt.scale != 1.0:
|
||||||
|
uc = model.get_learned_conditioning(batch_size * [""])
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
|
# encode (scaled latent)
|
||||||
|
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
|
||||||
|
# decode it
|
||||||
|
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
|
||||||
|
unconditional_conditioning=uc,)
|
||||||
|
|
||||||
|
x_samples = model.decode_first_stage(samples)
|
||||||
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if not opt.skip_save:
|
||||||
|
for x_sample in x_samples:
|
||||||
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
|
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||||
|
base_count += 1
|
||||||
|
all_samples.append(x_samples)
|
||||||
|
|
||||||
|
if not opt.skip_grid:
|
||||||
|
# additionally, save as grid
|
||||||
|
grid = torch.stack(all_samples, 0)
|
||||||
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
|
# to image
|
||||||
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
|
grid_count += 1
|
||||||
|
|
||||||
|
toc = time.time()
|
||||||
|
|
||||||
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
||||||
|
f" \nEnjoy.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -0,0 +1,398 @@
|
|||||||
|
import argparse, os, sys, glob
|
||||||
|
import clip
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
from itertools import islice
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
import scann
|
||||||
|
import time
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config, parallel_data_prefetch
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
|
||||||
|
|
||||||
|
DATABASES = [
|
||||||
|
"openimages",
|
||||||
|
"artbench-art_nouveau",
|
||||||
|
"artbench-baroque",
|
||||||
|
"artbench-expressionism",
|
||||||
|
"artbench-impressionism",
|
||||||
|
"artbench-post_impressionism",
|
||||||
|
"artbench-realism",
|
||||||
|
"artbench-romanticism",
|
||||||
|
"artbench-renaissance",
|
||||||
|
"artbench-surrealism",
|
||||||
|
"artbench-ukiyo_e",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(it, size):
|
||||||
|
it = iter(it)
|
||||||
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
|
print(f"Loading model from {ckpt}")
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
if len(m) > 0 and verbose:
|
||||||
|
print("missing keys:")
|
||||||
|
print(m)
|
||||||
|
if len(u) > 0 and verbose:
|
||||||
|
print("unexpected keys:")
|
||||||
|
print(u)
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class Searcher(object):
|
||||||
|
def __init__(self, database, retriever_version='ViT-L/14'):
|
||||||
|
assert database in DATABASES
|
||||||
|
# self.database = self.load_database(database)
|
||||||
|
self.database_name = database
|
||||||
|
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
|
||||||
|
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
|
||||||
|
self.retriever = self.load_retriever(version=retriever_version)
|
||||||
|
self.database = {'embedding': [],
|
||||||
|
'img_id': [],
|
||||||
|
'patch_coords': []}
|
||||||
|
self.load_database()
|
||||||
|
self.load_searcher()
|
||||||
|
|
||||||
|
def train_searcher(self, k,
|
||||||
|
metric='dot_product',
|
||||||
|
searcher_savedir=None):
|
||||||
|
|
||||||
|
print('Start training searcher')
|
||||||
|
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
|
||||||
|
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
|
||||||
|
k, metric)
|
||||||
|
self.searcher = searcher.score_brute_force().build()
|
||||||
|
print('Finish training searcher')
|
||||||
|
|
||||||
|
if searcher_savedir is not None:
|
||||||
|
print(f'Save trained searcher under "{searcher_savedir}"')
|
||||||
|
os.makedirs(searcher_savedir, exist_ok=True)
|
||||||
|
self.searcher.serialize(searcher_savedir)
|
||||||
|
|
||||||
|
def load_single_file(self, saved_embeddings):
|
||||||
|
compressed = np.load(saved_embeddings)
|
||||||
|
self.database = {key: compressed[key] for key in compressed.files}
|
||||||
|
print('Finished loading of clip embeddings.')
|
||||||
|
|
||||||
|
def load_multi_files(self, data_archive):
|
||||||
|
out_data = {key: [] for key in self.database}
|
||||||
|
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
|
||||||
|
for key in d.files:
|
||||||
|
out_data[key].append(d[key])
|
||||||
|
|
||||||
|
return out_data
|
||||||
|
|
||||||
|
def load_database(self):
|
||||||
|
|
||||||
|
print(f'Load saved patch embedding from "{self.database_path}"')
|
||||||
|
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
|
||||||
|
|
||||||
|
if len(file_content) == 1:
|
||||||
|
self.load_single_file(file_content[0])
|
||||||
|
elif len(file_content) > 1:
|
||||||
|
data = [np.load(f) for f in file_content]
|
||||||
|
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
|
||||||
|
n_proc=min(len(data), cpu_count()), target_data_type='dict')
|
||||||
|
|
||||||
|
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
|
||||||
|
self.database}
|
||||||
|
else:
|
||||||
|
raise ValueError(f'No npz-files in specified path "{self.database_path}". Does this directory exist?')
|
||||||
|
|
||||||
|
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
|
||||||
|
|
||||||
|
def load_retriever(self, version='ViT-L/14', ):
|
||||||
|
model = FrozenClipImageEmbedder(model=version)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def load_searcher(self):
|
||||||
|
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
|
||||||
|
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
|
||||||
|
print('Finished loading searcher.')
|
||||||
|
|
||||||
|
def search(self, x, k):
|
||||||
|
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
|
||||||
|
self.train_searcher(k) # quickly fit searcher on the fly for small databases
|
||||||
|
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
x = x.detach().cpu().numpy()
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[:, 0]
|
||||||
|
query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
out_embeddings = self.database['embedding'][nns]
|
||||||
|
out_img_ids = self.database['img_id'][nns]
|
||||||
|
out_pc = self.database['patch_coords'][nns]
|
||||||
|
|
||||||
|
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
|
||||||
|
'img_ids': out_img_ids,
|
||||||
|
'patch_coords': out_pc,
|
||||||
|
'queries': x,
|
||||||
|
'exec_time': end - start,
|
||||||
|
'nns': nns,
|
||||||
|
'q_embeddings': query_embeddings}
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __call__(self, x, n):
|
||||||
|
return self.search(x, n)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
|
||||||
|
# TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
default="a painting of a virus monster playing guitar",
|
||||||
|
help="the prompt to render"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--outdir",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="dir to write results to",
|
||||||
|
default="outputs/txt2img-samples"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_grid",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_steps",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="number of ddim sampling steps",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_repeat",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="number of repeats in CLIP latent space",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--plms",
|
||||||
|
action='store_true',
|
||||||
|
help="use plms sampling",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_eta",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_iter",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="sample this often",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--H",
|
||||||
|
type=int,
|
||||||
|
default=768,
|
||||||
|
help="image height, in pixel space",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--W",
|
||||||
|
type=int,
|
||||||
|
default=768,
|
||||||
|
help="image width, in pixel space",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_samples",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_rows",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="rows in the grid (default: n_samples)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--scale",
|
||||||
|
type=float,
|
||||||
|
default=5.0,
|
||||||
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--from-file",
|
||||||
|
type=str,
|
||||||
|
help="if specified, load prompts from this file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="configs/retrieval-augmented-diffusion/768x768.yaml",
|
||||||
|
help="path to config which constructs model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt",
|
||||||
|
type=str,
|
||||||
|
default="models/rdm/rdm768x768/model.ckpt",
|
||||||
|
help="path to checkpoint of model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--clip_type",
|
||||||
|
type=str,
|
||||||
|
default="ViT-L/14",
|
||||||
|
help="which CLIP model to use for retrieval and NN encoding",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--database",
|
||||||
|
type=str,
|
||||||
|
default='artbench-surrealism',
|
||||||
|
choices=DATABASES,
|
||||||
|
help="The database used for the search, only applied when --use_neighbors=True",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_neighbors",
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Include neighbors in addition to text prompt for conditioning",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--knn",
|
||||||
|
default=10,
|
||||||
|
type=int,
|
||||||
|
help="The number of included neighbors, only applied when --use_neighbors=True",
|
||||||
|
)
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
|
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||||
|
|
||||||
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
|
||||||
|
|
||||||
|
if opt.plms:
|
||||||
|
sampler = PLMSSampler(model)
|
||||||
|
else:
|
||||||
|
sampler = DDIMSampler(model)
|
||||||
|
|
||||||
|
os.makedirs(opt.outdir, exist_ok=True)
|
||||||
|
outpath = opt.outdir
|
||||||
|
|
||||||
|
batch_size = opt.n_samples
|
||||||
|
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||||
|
if not opt.from_file:
|
||||||
|
prompt = opt.prompt
|
||||||
|
assert prompt is not None
|
||||||
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"reading prompts from {opt.from_file}")
|
||||||
|
with open(opt.from_file, "r") as f:
|
||||||
|
data = f.read().splitlines()
|
||||||
|
data = list(chunk(data, batch_size))
|
||||||
|
|
||||||
|
sample_path = os.path.join(outpath, "samples")
|
||||||
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
|
base_count = len(os.listdir(sample_path))
|
||||||
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
print(f"sampling scale for cfg is {opt.scale:.2f}")
|
||||||
|
|
||||||
|
searcher = None
|
||||||
|
if opt.use_neighbors:
|
||||||
|
searcher = Searcher(opt.database)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with model.ema_scope():
|
||||||
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
|
all_samples = list()
|
||||||
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
print("sampling prompts:", prompts)
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
c = clip_text_encoder.encode(prompts)
|
||||||
|
uc = None
|
||||||
|
if searcher is not None:
|
||||||
|
nn_dict = searcher(c, opt.knn)
|
||||||
|
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
|
||||||
|
if opt.scale != 1.0:
|
||||||
|
uc = torch.zeros_like(c)
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
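# with the default 768x768 setting this is a (16, 48, 48) latent, matching the
# z_channels=16 / image_size=48 f16 autoencoder config earlier in this diff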
|
||||||
|
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||||
|
conditioning=c,
|
||||||
|
batch_size=c.shape[0],
|
||||||
|
shape=shape,
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=opt.scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
eta=opt.ddim_eta,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
for x_sample in x_samples_ddim:
|
||||||
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
|
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||||
|
base_count += 1
|
||||||
|
all_samples.append(x_samples_ddim)
|
||||||
|
|
||||||
|
if not opt.skip_grid:
|
||||||
|
# additionally, save as grid
|
||||||
|
grid = torch.stack(all_samples, 0)
|
||||||
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
|
# to image
|
||||||
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
|
grid_count += 1
|
||||||
|
|
||||||
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
||||||
@ -0,0 +1,147 @@
|
|||||||
|
import os, sys
|
||||||
|
import numpy as np
|
||||||
|
import scann
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ldm.util import parallel_data_prefetch
|
||||||
|
|
||||||
|

def search_bruteforce(searcher):
    return searcher.score_brute_force().build()


def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
                        partioning_trainsize, num_leaves, num_leaves_to_search):
    return searcher.tree(num_leaves=num_leaves,
                         num_leaves_to_search=num_leaves_to_search,
                         training_sample_size=partioning_trainsize). \
        score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()


def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
        reorder_k).build()
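
# The three helpers above correspond to the standard ScaNN configurations: exact brute-force
# scoring for small pools, asymmetric hashing (AH) with exact re-ranking of the top reorder_k
# candidates for medium pools, and tree partitioning on top of AH for large pools.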


def load_datapool(dpath):

    def load_single_file(saved_embeddings):
        compressed = np.load(saved_embeddings)
        database = {key: compressed[key] for key in compressed.files}
        return database

    def load_multi_files(data_archive):
        database = {key: [] for key in data_archive[0].files}
        for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
            for key in d.files:
                database[key].append(d[key])

        return database

    print(f'Load saved patch embedding from "{dpath}"')
    file_content = glob.glob(os.path.join(dpath, '*.npz'))

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    elif len(file_content) > 1:
        data = [np.load(f) for f in file_content]
        prefetched_data = parallel_data_prefetch(load_multi_files, data,
                                                 n_proc=min(len(data), cpu_count()), target_data_type='dict')

        data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
    else:
        raise ValueError(f'No npz-files found in the specified path "{dpath}". Does this directory exist?')

    print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool
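
# load_datapool expects one or more .npz archives in the database folder, each containing at
# least an "embedding" array of shape (N, D) with the CLIP features of the retrieval database;
# multiple archives are loaded in parallel and concatenated into a single pool.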


def train_searcher(opt,
                   metric='dot_product',
                   partioning_trainsize=None,
                   reorder_k=None,
                   # todo tune
                   aiq_thld=0.2,
                   dims_per_block=2,
                   num_leaves=None,
                   num_leaves_to_search=None,):

    data_pool = load_datapool(opt.database)
    k = opt.knn

    if not reorder_k:
        reorder_k = 2 * k

    # normalize the embeddings to unit length, then build the searcher with the chosen metric
    searcher = scann.scann_ops_pybind.builder(
        data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
    pool_size = data_pool['embedding'].shape[0]
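
    # With unit-norm embeddings, maximizing the dot product is equivalent to maximizing cosine
    # similarity, so the neighbors returned by the searcher are the most similar CLIP embeddings.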

    print(*(['#'] * 100))
    print('Initializing scaNN searcher with the following values:')
    print(f'k: {k}')
    print(f'metric: {metric}')
    print(f'reorder_k: {reorder_k}')
    print(f'anisotropic_quantization_threshold: {aiq_thld}')
    print(f'dims_per_block: {dims_per_block}')
    print(*(['#'] * 100))
    print('Start training searcher....')
    print(f'N samples in pool is {pool_size}')

    # this reflects the recommended design choices proposed at
    # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
    if pool_size < 2e4:
        print('Using brute force search.')
        searcher = search_bruteforce(searcher)
    elif 2e4 <= pool_size < 1e5:
        print('Using asymmetric hashing search and reordering.')
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
        print('Using partitioning, asymmetric hashing search and reordering.')

        if not partioning_trainsize:
            partioning_trainsize = data_pool['embedding'].shape[0] // 10
        if not num_leaves:
            num_leaves = int(np.sqrt(pool_size))

        if not num_leaves_to_search:
            num_leaves_to_search = max(num_leaves // 20, 1)

        print('Partitioning params:')
        print(f'num_leaves: {num_leaves}')
        print(f'num_leaves_to_search: {num_leaves_to_search}')
        searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
                                       partioning_trainsize, num_leaves, num_leaves_to_search)
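
        # The defaults above follow the ScaNN guidance referenced earlier: roughly sqrt(N) leaves,
        # search about 5% of them per query, and train the partitioner on about 10% of the pool.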

    print('Finished training searcher')
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')


if __name__ == '__main__':
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument('--database',
                        '-d',
                        default='data/rdm/retrieval_databases/openimages',
                        type=str,
                        help='path to the folder containing the CLIP features of the database')
    parser.add_argument('--target_path',
                        '-t',
                        default='data/rdm/searchers/openimages',
                        type=str,
                        help='path to the target folder where the searcher shall be stored.')
    parser.add_argument('--knn',
                        '-k',
                        default=20,
                        type=int,
                        help='number of nearest neighbors for which the searcher shall be optimized')

    opt, _ = parser.parse_known_args()
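
    # Example invocation (hypothetical; the script name is assumed, the paths are the defaults above):
    #   python train_searcher.py -d data/rdm/retrieval_databases/openimages -t data/rdm/searchers/openimages -k 20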

    train_searcher(opt)


@ -0,0 +1,279 @@
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
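
# chunk() yields successive tuples of up to `size` items from an iterable; it is used below to
# split the prompts read from --from-file into batches of --n_samples prompts each.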


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=2,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    opt = parser.parse_args()
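
    # Example invocation (hypothetical; script path and prompt are illustrative and assume the
    # default config and checkpoint paths above exist):
    #   python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --n_samples 3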

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    seed_everything(opt.seed)

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir
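
    # --plms selects the pseudo linear multistep sampler; otherwise DDIM is used, with --ddim_eta
    # controlling how much stochasticity is injected (eta=0.0 gives deterministic sampling).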

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
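
    # With --fixed_code the same initial latent noise x_T is reused for every batch, which makes it
    # easy to compare the effect of different prompts or settings on otherwise identical noise.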

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
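                        # get_learned_conditioning runs the prompts through the model's frozen CLIP
                        # text encoder; the empty-string embedding serves as the unconditional input
                        # for classifier-free guidance when scale != 1.0.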
                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                         conditioning=c,
                                                         batch_size=opt.n_samples,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=opt.scale,
                                                         unconditional_conditioning=uc,
                                                         eta=opt.ddim_eta,
                                                         x_T=start_code)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                        if not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(sample_path, f"{base_count:05}.png"))
                                base_count += 1

                        if not opt.skip_grid:
                            all_samples.append(x_samples_ddim)
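
                # After the sampling loops, tile all n_iter x n_samples images into one grid with
                # n_rows images per row (n_rows defaults to the batch size), as done below.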
                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

                toc = time.time()

    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
          f" \nEnjoy.")


if __name__ == "__main__":
    main()