Skip to content

Commit bd37e42

Browse files
Dataviz/add slider to python api (#2)
* feat: create a image slider dataviz * cleanup the js code * fix typos * Removing JSON parse * add GeneratedImages * move code to show method instead * bump version * update notebook * update readme with slider information * remove bounding box from last image * change image_slider.gif * add a loading animation and fix a few minor issues * fix typo * try to display loading before image slider * make loading enabled at initialization * try display_id=42 * try to display both loading and slider at once * try with update display * Update README.md * update notebook Co-authored-by: André Batista <hi@andrebatista.dev>
1 parent 34dfd41 commit bd37e42

6 files changed

Lines changed: 325 additions & 130 deletions

File tree

README.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,8 @@ output['sample']
5656

5757
![](assets/corgi_eiffel_tower.png)
5858

59-
You can also check the image that the diffusion process generated in the end of each step.
60-
61-
For example, to see the image from step 10:
62-
```python
63-
output['all_samples_during_generation'][10]
64-
```
65-
![](assets/corgi_eiffel_tower_step10.png)
59+
You can also check all the images that the diffusion process generated at the end of each step.
60+
![](assets/image_slider.gif)
6661

6762
To check how a token in the input `prompt` influenced the generation, you can check the token attribution scores:
6863
```python
@@ -132,7 +127,7 @@ The token attributions are now computed only for the area specified in the image
132127
Check out other functionalities and more implementation examples [here](https://github.com/JoaoLages/diffusers-interpret/blob/main/notebooks/).
133128

134129
## Future Development
135-
- [ ] Add interactive display of all the images that were generated in the diffusion process
130+
- [x] ~~Add interactive display of all the images that were generated in the diffusion process~~
136131
- [ ] Add interactive bounding-box and token attributions visualization
137132
- [ ] Add unit tests
138133
- [ ] Add example for `diffusers_interpret.LDMTextToImagePipelineExplainer`
@@ -141,3 +136,7 @@ Check other functionalities and more implementation examples in [here](https://g
141136

142137
## Contributing
143138
Feel free to open an [Issue](https://github.com/JoaoLages/diffusers-interpret/issues) or create a [Pull Request](https://github.com/JoaoLages/diffusers-interpret/pulls) and let's get started 🚀
139+
140+
## Credits
141+
142+
A special thanks to [@andrewizbatista](https://github.com/andrewizbatista) for creating a great [image slider](https://github.com/JoaoLages/diffusers-interpret/pull/1) to show all the generated images during diffusion! 💪

assets/image_slider.gif

4.54 MB
Loading

notebooks/stable-diffusion-example.ipynb

Lines changed: 147 additions & 98 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setup(
1111
name='diffusers-interpret',
12-
version='0.1.0',
12+
version='0.2.0',
1313
description='diffusers-interpret: model explainability for 🤗 Diffusers',
1414
long_description=long_description,
1515
long_description_content_type='text/markdown',

src/diffusers_interpret/explainer.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
from abc import ABC, abstractmethod
22
from typing import List, Optional, Union, Dict, Any, Tuple, Set
33

4-
from PIL import ImageDraw
5-
64
import torch
5+
from PIL import ImageDraw
76
from diffusers import DiffusionPipeline
87
from transformers import BatchEncoding, PreTrainedTokenizerBase
98

109
from diffusers_interpret.attribution import gradient_x_inputs_attribution
10+
from diffusers_interpret.generated_images import GeneratedImages
1111
from diffusers_interpret.utils import clean_token_from_prefixes_and_suffixes, transform_images_to_pil_format
1212

1313

1414
class BasePipelineExplainer(ABC):
15-
def __init__(self, pipe: DiffusionPipeline, verbose: bool = True):
15+
def __init__(self, pipe: DiffusionPipeline, verbose: bool = True) -> None:
1616
self.pipe = pipe
1717
self.verbose = verbose
1818

@@ -117,33 +117,35 @@ def __call__(
117117
if self.verbose:
118118
print("Done!")
119119

120-
# convert to PIL Image if requested
121-
# also draw bounding box if requested
122-
if output_type == "pil":
123-
images_with_bounding_box = []
124-
all_samples = output['all_samples_during_generation'] or [output['sample']]
125-
for list_im in transform_images_to_pil_format(all_samples, self.pipe):
126-
batch_images = []
127-
for im in list_im:
128-
if explanation_2d_bounding_box:
129-
draw = ImageDraw.Draw(im)
130-
draw.rectangle(explanation_2d_bounding_box, outline="red")
131-
batch_images.append(im)
132-
images_with_bounding_box.append(batch_images)
133-
134-
if output['all_samples_during_generation']:
135-
output['all_samples_during_generation'] = images_with_bounding_box
136-
output['sample'] = output['all_samples_during_generation'][-1]
137-
else:
138-
output['sample'] = images_with_bounding_box[-1]
139-
140120
if batch_size == 1:
141121
# squash batch dimension
142122
for k in ['sample', 'token_attributions', 'normalized_token_attributions']:
143123
output[k] = output[k][0]
144124
if output['all_samples_during_generation']:
145125
output['all_samples_during_generation'] = [b[0] for b in output['all_samples_during_generation']]
146126

127+
# convert to PIL Image if requested
128+
# also draw bounding box in the last image if requested
129+
if output['all_samples_during_generation'] or output_type == "pil":
130+
all_samples = GeneratedImages(
131+
all_generated_images=output['all_samples_during_generation'] or [output['sample']],
132+
pipe=self.pipe,
133+
remove_batch_dimension=batch_size==1,
134+
prepare_image_slider=bool(output['all_samples_during_generation'])
135+
)
136+
if output['all_samples_during_generation']:
137+
output['all_samples_during_generation'] = all_samples
138+
sample = output['all_samples_during_generation'][-1]
139+
else:
140+
sample = all_samples[-1]
141+
142+
if explanation_2d_bounding_box:
143+
draw = ImageDraw.Draw(sample)
144+
draw.rectangle(explanation_2d_bounding_box, outline="red")
145+
146+
if output_type == "pil":
147+
output['sample'] = sample
148+
147149
return output
148150

149151
@property
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import base64
2+
import json
3+
import os
4+
from typing import List, Union
5+
6+
import torch
7+
from IPython import display as d
8+
from PIL.Image import Image
9+
from diffusers import DiffusionPipeline
10+
11+
import diffusers_interpret
12+
from diffusers_interpret.utils import transform_images_to_pil_format
13+
14+
15+
class GeneratedImages:
    """
    Holds the PIL images produced during generation and, optionally, the
    IFrames used to browse them with an interactive image slider in a notebook.

    `self.images` is a flat list of images when the batch dimension was removed,
    otherwise a list with one list of images per entry of `all_generated_images`.
    """

    def __init__(
        self,
        all_generated_images: List[torch.Tensor],
        pipe: DiffusionPipeline,
        remove_batch_dimension: bool = True,
        prepare_image_slider: bool = True
    ) -> None:
        """
        Args:
            all_generated_images: non-empty list of generated images, converted
                with `transform_images_to_pil_format`.
            pipe: pipeline forwarded to `transform_images_to_pil_format`.
            remove_batch_dimension: if True, the per-batch lists are flattened
                into a single list of images.
            prepare_image_slider: if True, build the slider HTML right away.
                Skipped when `self.images` is nested, since the slider does
                not support that layout (see `show`).
        """

        assert all_generated_images, "Can't create GeneratedImages object with empty `all_generated_images`"

        # Convert images to PIL
        self.images = []
        for list_im in transform_images_to_pil_format(all_generated_images, pipe):
            batch_images = list(list_im)
            if remove_batch_dimension:
                self.images.extend(batch_images)
            else:
                self.images.append(batch_images)

        self.loading_iframe = None
        self.image_slider_iframe = None
        # Building the slider for nested images would fail on `image.save`,
        # so only prepare it when `self.images` is a flat list.
        if prepare_image_slider and not self._has_batch_dimension():
            self.prepare_image_slider()

    def _has_batch_dimension(self) -> bool:
        """Returns True when `self.images` is a list of lists of images."""
        return bool(self.images) and isinstance(self.images[0], list)

    def prepare_image_slider(self) -> None:
        """
        Creates the auxiliary HTML files (`loading.html` and `final.html`) and
        the IFrames pointing at them, to be displayed in `self.show`.

        Raises:
            NotImplementedError: if `self.images` is a list of lists of images.
        """
        # Local import keeps this block self-contained without touching the module imports
        from io import BytesIO

        if self._has_batch_dimension():
            raise NotImplementedError("GeneratedImages.show visualization is not supported "
                                      "when `self.images` is a list of lists of images")

        # Get data dir
        image_slider_dir = os.path.join(os.path.dirname(diffusers_interpret.__file__), "dataviz", "image-slider")

        # Convert images to base64, encoding in memory instead of going through a temporary file
        json_payload = []
        for image in self.images:
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            json_payload.append(
                {"image": "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode('utf-8')}
            )

        # get HTML file
        with open(os.path.join(image_slider_dir, "index.html"), encoding="utf-8") as fp:
            html = fp.read()

        # get CSS file
        with open(os.path.join(image_slider_dir, "css/index.css"), encoding="utf-8") as fp:
            css = fp.read()

        # get JS file
        with open(os.path.join(image_slider_dir, "js/index.js"), encoding="utf-8") as fp:
            js = fp.read()

        # inline the CSS in the HTML file
        html = html.replace("""<link href="css/index.css" rel="stylesheet" />""",
                            f"""<style type="text/css">\n{css}</style>""")

        # inline the JS in the HTML file
        html = html.replace("""<script type="text/javascript" src="js/index.js"></script>""",
                            f"""<script type="text/javascript">\n{js}</script>""")

        # get html with image slider JS call
        # NOTE: if the marker is missing, `find` returns -1 and the script is
        # inserted right before the last character of the HTML.
        index = html.find("<!-- INSERT STARTING SCRIPT HERE -->")
        add = """
        <script type="text/javascript">
          ((d) => {
            const $body = d.querySelector("body");

            if ($body) {
              $body.addEventListener("INITIALIZE_IS_READY", ({ detail }) => {
                const initialize = detail?.initialize ?? null;

                if (initialize) initialize(%s);
              });
            }
          })(document);
        </script>
        """ % json.dumps(json_payload)
        html_with_image_slider = html[:index] + add + html[index:]

        # save files and load IFrames to be displayed in self.show
        loading_path = os.path.join(image_slider_dir, "loading.html")
        final_path = os.path.join(image_slider_dir, "final.html")
        with open(loading_path, 'w', encoding="utf-8") as fp:
            fp.write(html)
        with open(final_path, 'w', encoding="utf-8") as fp:
            fp.write(html_with_image_slider)

        self.loading_iframe = d.IFrame(
            os.path.relpath(loading_path, '.'),
            width="100%", height="400px"
        )

        self.image_slider_iframe = d.IFrame(
            os.path.relpath(final_path, '.'),
            width="100%", height="400px"
        )

    def __getitem__(self, item: int) -> Union[Image, List[Image]]:
        """Returns the entry of `self.images` at position `item`."""
        return self.images[item]

    def show(self, width: Union[str, int] = "100%", height: Union[str, int] = "400px") -> None:
        """
        Displays the loading page and then replaces it with the image slider.

        Args:
            width: width applied to both IFrames.
            height: height applied to both IFrames.

        Raises:
            Exception: if `self.images` is empty.
            NotImplementedError: if `self.images` is a list of lists of images.
        """

        if len(self.images) == 0:
            raise Exception("`self.images` is an empty list, can't show any images")

        if isinstance(self.images[0], list):
            raise NotImplementedError("GeneratedImages.show visualization is not supported "
                                      "when `self.images` is a list of lists of images")

        if self.image_slider_iframe is None:
            self.prepare_image_slider()

        # display loading
        self.loading_iframe.width = width
        self.loading_iframe.height = height
        # display_id=True generates a fresh id per call, so updating this
        # display does not touch outputs created by earlier `show` calls
        handle = d.display(self.loading_iframe, display_id=True)

        # display image slider
        self.image_slider_iframe.width = width
        self.image_slider_iframe.height = height
        handle.update(self.image_slider_iframe)

0 commit comments

Comments
 (0)