Skip to content

Commit d506289

Browse files
committed
Upload code
1 parent 6353519 commit d506289

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+4941
-3
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,4 @@ cython_debug/
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
#.idea/
161+
.DS_Store

Diff for: LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2023 Pablo Marcos
3+
Copyright (c) 2023
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

Diff for: README.md

+129-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,129 @@
1-
# ovam
2-
Open-Vocabulary Attention Maps with Token Optimization for Semantic Segmentation in Diffusion Models
1+
# Open-Vocabulary Attention Maps (OVAM)
2+
3+
**Open-Vocabulary Attention Maps with Token Optimization for Semantic Segmentation in Diffusion Models**
4+
5+
6+
[]([![arXiv](https://img.shields.io/badge/arXiv-abcd.efgh-b31b1b.svg)](https://arxiv.org/abs/abcd.efgh))
7+
8+
> Links have been removed for anonymity.
9+
10+
![teaser](docs/assets/teaser.svg)
11+
12+
In [our paper](https://arxig.org), we introduce Open-Vocabulary Attention Maps (OVAM), a training-free extension for text-to-image diffusion models to generate text-attribution maps based on open vocabulary descriptions. Also, we introduce a token optimization process to the creation of accurate attention maps, improving the performance of existing semantic segmentation methods based on diffusion cross-attention maps.
13+
14+
![diagram](docs/assets/diagram-OVAM.svg)
15+
16+
## Installation
17+
18+
Create a new virtual or conda environment using (if applicable) and activate it. As example, using `venv`:
19+
20+
```bash
21+
# Install a Python environment (Ensure 3.8 or higher)
22+
python -m venv venv
23+
source venv/bin/activate
24+
pip install --upgrade pip wheel
25+
```
26+
27+
Install Pytorch with a compatible CUDA or other backend and [Diffusers 0.20](https://pypi.org/project/diffusers/0.20.2/). In our experiments, we tested the code in Ubuntu with CUDA 11.8 and in MacOS with MPS backend.
28+
29+
```bash
30+
# Install PyTorch with CUDA 11.8
31+
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
32+
```
33+
34+
```bash
35+
# Or Pytorch with MPS backend for MacOS
36+
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0
37+
```
38+
39+
Install Python dependencies using project file or alternatively install them from `requirements.txt`:
40+
41+
```bash
42+
# Install using pyproject.toml
43+
pip install .
44+
```
45+
46+
Or alternatively, install dependencies from `requirements.txt` and add OVAM to your PYTHONPATH.
47+
48+
## Getting started
49+
50+
The jupyter notebook [examples/getting_started.ipynb](./examples/getting_started.ipynb) contains a full example of how to use OVAM with Stable Diffusion. In this section, we will show a simplified version of the notebook.
51+
52+
### Setup
53+
Import related libraries and load Stable Diffusion
54+
55+
```python
56+
import torch
57+
import matplotlib.pyplot as plt
58+
from diffusers import StableDiffusionPipeline
59+
from ovam.stable_diffusion import StableDiffusionHooker
60+
from ovam.utils import set_seed
61+
62+
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
63+
pipe = pipe.to("mps") #mps, cuda, ...
64+
```
65+
66+
Generate and image with Stable Diffusion and store the attention maps using OVAM hooker.
67+
68+
```python
69+
with StableDiffusionHooker(pipe) as hooker:
70+
set_seed(123456)
71+
out = pipe("monkey with hat walking")
72+
image = out.images[0]
73+
```
74+
### Generate and attention map with open vocabulary
75+
76+
Extract attention maps for the attribution prompt `monkey with hat walking and mouth`:
77+
78+
```python
79+
ovam_evaluator = hooker.get_ovam_callable(
80+
expand_size=(512, 512)
81+
) # You can configure OVAM here (aggregation, activations, size, ...)
82+
83+
with torch.no_grad():
84+
attention_maps = ovam_evaluator("monkey with hat walking and mouth")
85+
attention_maps = attention_maps[0].cpu().numpy() # (8, 512, 512)
86+
```
87+
88+
Have been generated 8 attention maps for the tokens: `0:<SoT>, 1:monkey, 2:with, 3:hat, 4:walking, 5:and, 6:mouth, 7:<EoT>`. Plot attention maps for words `monkey`, `hat` and `mouth`:
89+
90+
```python
91+
# Get maps for monkey, hat and mouth
92+
monkey = attention_maps[1]
93+
hat = attention_maps[3]
94+
mouth = attention_maps[6]
95+
96+
# Plot using matplotlib
97+
fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(20, 5))
98+
ax0.imshow(image)
99+
ax1.imshow(monkey, alpha=monkey / monkey.max())
100+
ax2.imshow(hat, alpha=hat / hat.max())
101+
ax3.imshow(mouth, alpha=mouth / mouth.max())
102+
plt.show()
103+
```
104+
Result (matplotlib code simplified, full in [examples/getting_started.ipynb](./examples/getting_started.ipynb)):
105+
![result](docs/assets/attention_maps.svg)
106+
107+
### Token optimization
108+
109+
OVAM library include code to optimize the tokens to improve the attention maps. Given an image generated with Stable Diffusion using the text `a photograph of a cat in a park`, we optimized a cat token for obtaining a mask of the cat in the image (full example in notebook).
110+
111+
![Token optimization](docs/assets/optimized_training_attention.svg)
112+
113+
This token, can be later used for generating a mask of the cat in other testing images. For example, in this image generated with the text `cat perched on the sofa looking out of the window`.
114+
115+
![Token optimization](docs/assets/optimized_testing_attention.svg)
116+
117+
### Different Stable Diffusion versions
118+
119+
The current code have been tested with Stable Diffusion 1.5, 2.0 base and 2.1 in Diffusers 0.20. We provide a module ovam/base with utility classes for adapt OVAM to other Diffusion Models.
120+
121+
## Experiments
122+
123+
## Aknowledgements
124+
125+
We want to thank the authors of [DAAM](https://github.com/castorini/daam) for their helpful code. A big thanks also to the open-source community of [HuggingFace](https://huggingface.co/docs/diffusers/index), [PyTorch](https://pytorch.org/), and RunwayML for making [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) available. We also aknowledge the work of the teams behind [DatasetDM](https://github.com/showlab/DatasetDM), [DiffuMask](https://github.com/weijiawu/DiffuMask) and [Grounded Diffusion](https://github.com/Lipurple/Grounded-Diffusion), which we used in our experiments.
126+
127+
## Citation
128+
129+
> Pending publication.

Diff for: docs/assets/attention_maps.svg

+613
Loading

Diff for: docs/assets/cat.jpg

80.5 KB
Loading

Diff for: docs/assets/cat_annotation.png

2.31 KB
Loading

Diff for: docs/assets/cat_optimized_token.npy

6.13 KB
Binary file not shown.

Diff for: docs/assets/diagram-OVAM.svg

+1
Loading

Diff for: docs/assets/optimized_testing_attention.svg

+588
Loading

Diff for: docs/assets/optimized_training_attention.svg

+671
Loading

Diff for: docs/assets/teaser.svg

+1
Loading

Diff for: examples/getting_started.ipynb

+623
Large diffs are not rendered by default.

Diff for: ovam/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .stable_diffusion_sa import StableDiffusionHookerSA as StableDiffusionHooker
2+
3+
__version__ = "0.0.1"
4+
5+
__all__ = ["StableDiffusionHooker", "__version__"]

Diff for: ovam/base/attention_storage.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Class in charge of storing the hidden states of a block.
2+
3+
The class OnlineAttentionStorage allows a simple version that
4+
stores all the hidden states in memory. The AttentionStorage
5+
class is a generic class that can be used to implement more
6+
complex storage classes.
7+
8+
"""
9+
from typing import TYPE_CHECKING, Iterable, Optional, List
10+
11+
if TYPE_CHECKING:
12+
import torch
13+
14+
__all__ = ["AttentionStorage", "OnlineAttentionStorage"]
15+
16+
17+
class AttentionStorage:
18+
"""Generic class for storing hidden states of upsample/downsample block.
19+
20+
Attributes
21+
----------
22+
name: str
23+
The name of the block in the UNet.
24+
"""
25+
26+
def __init__(self, name: Optional[str] = None) -> None:
27+
self.name = name
28+
29+
def store(self, hidden_states: "torch.Tensor") -> None:
30+
"""Stores the hidden states.
31+
32+
Arguments
33+
---------
34+
hidden_states: List[torch.Tensor]
35+
The hidden states of a block generated by an image. The
36+
hidden states are stored in the order they are passed.
37+
"""
38+
raise NotImplementedError
39+
40+
def __len__(self) -> int:
41+
"""Returns the number of images stored"""
42+
raise NotImplementedError
43+
44+
def __getitem__(self, idx: int) -> "torch.Tensor":
45+
"""Returns the hidden state at the given index."""
46+
raise NotImplementedError
47+
48+
def __iter__(self) -> Iterable["torch.Tensor"]:
49+
"""Returns an iterator over the stored hidden states."""
50+
for i in range(len(self)):
51+
yield self[i]
52+
53+
def clear(self) -> None:
54+
"""Clears the stored hidden states."""
55+
raise NotImplementedError
56+
57+
58+
class OnlineAttentionStorage(AttentionStorage):
59+
"""Class to store the hidden states in memory.
60+
61+
Attributes
62+
----------
63+
block_name: str
64+
The name of the block in the UNet.
65+
"""
66+
67+
def __init__(self, name: Optional[str] = None):
68+
super().__init__(name)
69+
self.hidden_states: List["torch.Tensor"] = []
70+
71+
def store(self, hidden_states: "torch.Tensor") -> None:
72+
"""Stores the hidden states.
73+
74+
Arguments
75+
---------
76+
hidden_states: List[torch.Tensor]
77+
The hidden states of a block generated by an image. The
78+
hidden states are stored in the order they are passed.
79+
"""
80+
self.hidden_states.append(hidden_states)
81+
82+
def __len__(self) -> int:
83+
return len(self.hidden_states)
84+
85+
def __getitem__(self, idx: int) -> "torch.Tensor":
86+
return self.hidden_states[idx]
87+
88+
def clear(self) -> None:
89+
"""Clears the stored hidden states."""
90+
self.hidden_states.clear()

Diff for: ovam/base/block_hooker.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import TYPE_CHECKING
2+
3+
from .hooker import ObjectHooker, ModuleType
4+
from .attention_storage import OnlineAttentionStorage
5+
6+
if TYPE_CHECKING:
7+
import torch
8+
from .daam_block import DAAMBlock
9+
10+
11+
class BlockHooker(ObjectHooker["ModuleType"]):
12+
"""Hooker for the CrossAttention blocks.
13+
14+
Monkey patches the forward method of the cross attention blocks of the
15+
Stable Diffusion UNET.
16+
17+
Arguments
18+
---------
19+
20+
module: CrossAttention
21+
Cross Attention moduled to be hooked.
22+
block_index: int
23+
Block index
24+
25+
Attributes
26+
----------
27+
module: CrossAttention
28+
Cross Attention module hooked
29+
block_index: int
30+
Block index
31+
hidden_states: List[torch.Tensor]
32+
List of hidden states hoked with size [ h*w ] x n_heads, where
33+
`h*w` is the size flattended of the unet hidden state through the block,
34+
(equal to h*w / (2**2*factor)) and n_heads the number of attention heads
35+
of the module.
36+
37+
Note
38+
----
39+
This class is based on the original implementation `daam.trace.UNetCrossAttentionHooker`.
40+
"""
41+
42+
# Default class to store the hidden states (in memory)
43+
STORAGE_CLASS = OnlineAttentionStorage
44+
45+
def __init__(self, module: "ModuleType", name: str):
46+
super().__init__(module)
47+
self.name = name
48+
self.hidden_states = self.STORAGE_CLASS(name=name)
49+
50+
def __repr__(self):
51+
return f"{self.__class__.__name__}({self.name})"
52+
53+
def store_hidden_states(self) -> None:
54+
"""Stores the hidden states in the parent trace"""
55+
raise NotImplementedError
56+
57+
def clear(self) -> None:
58+
"""Clear the hidden states"""
59+
self.hidden_states.clear()
60+
61+
def _hook_impl(self) -> None:
62+
"""Monkey patches the forward method in the cross attention block"""
63+
self.monkey_patch("forward", self._hooked_forward)
64+
65+
def _hooked_forward(
66+
hk_self: "BlockHooker",
67+
_: "ModuleType",
68+
hidden_states: "torch.Tensor",
69+
):
70+
"""Hooked forward of the cross attention module.
71+
72+
Stores the hidden states and perform the original attention.
73+
"""
74+
raise NotImplementedError
75+
76+
def daam_block(self) -> "DAAMBlock":
77+
"""Builds a DAAMBlock with the current hidden states.
78+
79+
Arguments
80+
---------
81+
**kwargs:
82+
Arguments passed to the `DAAMBlock` constructor.
83+
"""
84+
85+
raise NotImplementedError

Diff for: ovam/base/daam_block.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import TYPE_CHECKING
2+
3+
from torch import nn
4+
5+
if TYPE_CHECKING:
6+
from .attention_storage import AttentionStorage
7+
8+
9+
class DAAMBlock(nn.Module):
10+
"""Generic DAAMBlock used to save the hidden states of the cross attention blocks.
11+
12+
Should be implemented by each of the different architectures.
13+
It is used to save the hidden states of the cross attention blocks and to
14+
build a callable DAAM function.
15+
"""
16+
17+
def __init__(
18+
self,
19+
hidden_states: "AttentionStorage",
20+
name: str,
21+
):
22+
super().__init__()
23+
self.name = name
24+
self.hidden_states = hidden_states
25+
26+
def forward(self, x):
27+
"""Compute the attention for a given input x"""
28+
29+
return NotImplementedError
30+
31+
def store_hidden_states(self) -> None:
32+
"""Stores the hidden states in the parent trace"""
33+
raise NotImplementedError

0 commit comments

Comments
 (0)