Skip to content

Commit 6559413

Browse files
authored
Serve _upload files with content-disposition: attachment (#6303)
* Serve _upload files with content-disposition: attachment * All files served from the /_upload endpoint will now trigger download * PDF files are exempt, but they will always be served with the content-type set * If an application needs some other behavior, they are encouraged to mount their own StaticFiles instance to handle it. * remove stray breakpoint * test_upload: assert that files are properly downloaded * Do not add or replace Content-Type header If the Content-Type is missing or not application/pdf, then we serve the file with Content-Disposition: attachment. Iterate through the headers in a way that avoids adding duplicate headers.
1 parent 8fb2fb4 commit 6559413

File tree

5 files changed

+241
-41
lines changed

5 files changed

+241
-41
lines changed

packages/reflex-base/src/reflex_base/constants/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
"RouteArgType",
115115
"RouteRegex",
116116
"RouteVar",
117+
"RunningMode",
117118
"SocketEvent",
118119
"StateManagerMode",
119120
"Templates",

packages/reflex-components-core/src/reflex_components_core/core/_upload.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import contextlib
77
import dataclasses
88
from collections import deque
9-
from collections.abc import AsyncGenerator, AsyncIterator
9+
from collections.abc import AsyncGenerator, AsyncIterator, MutableMapping
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Any, BinaryIO, cast
1212

@@ -23,6 +23,8 @@
2323
from typing_extensions import Self
2424

2525
if TYPE_CHECKING:
26+
from reflex_base.utils.types import ASGIApp, Receive, Scope, Send
27+
2628
from reflex.app import App
2729

2830

@@ -575,6 +577,62 @@ async def _upload_chunk_file(
575577
return Response(status_code=202)
576578

577579

580+
header_content_disposition = b"content-disposition"
581+
header_content_type = b"content-type"
582+
header_x_content_type_options = b"x-content-type-options"
583+
584+
585+
class UploadedFilesHeadersMiddleware:
586+
"""ASGI middleware that adds security headers to uploaded file responses."""
587+
588+
def __init__(self, app: ASGIApp) -> None:
589+
"""Wrap an ASGI application with upload security headers.
590+
591+
Args:
592+
app: The ASGI application to wrap.
593+
"""
594+
self.app = app
595+
596+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
597+
"""Add Content-Disposition and X-Content-Type-Options headers.
598+
599+
Args:
600+
scope: The ASGI scope.
601+
receive: The ASGI receive callable.
602+
send: The ASGI send callable.
603+
"""
604+
if scope["type"] != "http":
605+
await self.app(scope, receive, send)
606+
return
607+
608+
async def send_with_headers(message: MutableMapping[str, Any]) -> None:
609+
if message["type"] == "http.response.start":
610+
content_disposition = None
611+
content_type = None
612+
headers = [(header_x_content_type_options, b"nosniff")]
613+
for header_name, header_value in message.get("headers", []):
614+
lower_name = header_name.lower()
615+
if lower_name == header_content_disposition:
616+
content_disposition = header_value.lower()
617+
# Always append content-disposition header if non-empty.
618+
continue
619+
if lower_name == header_x_content_type_options:
620+
# Always replace this value with "nosniff", so ignore existing value.
621+
continue
622+
if lower_name == header_content_type:
623+
content_type = header_value.lower()
624+
headers.append((header_name, header_value))
625+
if content_type != b"application/pdf":
626+
# Unknown content or non-PDF forces download.
627+
content_disposition = b"attachment"
628+
if content_disposition:
629+
headers.append((header_content_disposition, content_disposition))
630+
message = {**message, "headers": headers}
631+
await send(message)
632+
633+
await self.app(scope, receive, send_with_headers)
634+
635+
578636
def upload(app: App):
579637
"""Upload files, dispatching to buffered or streaming handling.
580638

reflex/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@
6868
from starlette.staticfiles import StaticFiles
6969
from typing_extensions import Unpack
7070

71+
from reflex._upload import UploadedFilesHeadersMiddleware, upload
7172
from reflex._upload import UploadFile as UploadFile
72-
from reflex._upload import upload
7373
from reflex.admin import AdminDash
7474
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
7575
from reflex.compiler import compiler
@@ -714,7 +714,7 @@ def _add_optional_endpoints(self):
714714
# To access uploaded files.
715715
self._api.mount(
716716
str(constants.Endpoint.UPLOAD),
717-
StaticFiles(directory=get_upload_dir()),
717+
UploadedFilesHeadersMiddleware(StaticFiles(directory=get_upload_dir())),
718718
name="uploaded_files",
719719
)
720720

tests/integration/test_media.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ def img_from_url(self) -> Image.Image:
5555
img_bytes = img_resp.content
5656
return Image.open(io.BytesIO(img_bytes))
5757

58+
@rx.var
59+
def generated_image(self) -> str:
60+
# Generate a 150x150 red PNG and write it to the upload directory.
61+
img = Image.new("RGB", (150, 150), "red")
62+
upload_dir = rx.get_upload_dir()
63+
upload_dir.mkdir(parents=True, exist_ok=True)
64+
img.save(upload_dir / "generated.png")
65+
return "generated.png"
66+
5867
app = rx.App()
5968

6069
@app.add_page
@@ -72,19 +81,53 @@ def index():
7281
rx.image(src=State.img_gif, alt="GIF image", id="gif"),
7382
rx.image(src=State.img_webp, alt="WEBP image", id="webp"),
7483
rx.image(src=State.img_from_url, alt="Image from URL", id="from_url"),
84+
rx.image(
85+
src=rx.get_upload_url(State.generated_image),
86+
alt="Uploaded image",
87+
id="uploaded",
88+
),
7589
)
7690

7791

92+
def check_image_loaded(
93+
driver, img, expected_width: int = 200, expected_height: int = 200
94+
) -> bool:
95+
"""Check whether an image element has fully loaded with expected dimensions.
96+
97+
Args:
98+
driver: WebDriver instance.
99+
img: The image WebElement.
100+
expected_width: Expected natural width.
101+
expected_height: Expected natural height.
102+
103+
Returns:
104+
True if the image is complete and matches the expected dimensions.
105+
"""
106+
return driver.execute_script(
107+
"return arguments[0].complete "
108+
'&& typeof arguments[0].naturalWidth != "undefined" '
109+
"&& arguments[0].naturalWidth === arguments[1] "
110+
'&& typeof arguments[0].naturalHeight != "undefined" '
111+
"&& arguments[0].naturalHeight === arguments[2]",
112+
img,
113+
expected_width,
114+
expected_height,
115+
)
116+
117+
78118
@pytest.fixture
79-
def media_app(tmp_path) -> Generator[AppHarness, None, None]:
119+
def media_app(tmp_path, monkeypatch) -> Generator[AppHarness, None, None]:
80120
"""Start MediaApp app at tmp_path via AppHarness.
81121
82122
Args:
83123
tmp_path: pytest tmp_path fixture
124+
monkeypatch: pytest monkeypatch fixture
84125
85126
Yields:
86127
running AppHarness instance
87128
"""
129+
monkeypatch.setenv("REFLEX_UPLOADED_FILES_DIR", str(tmp_path / "uploads"))
130+
88131
with AppHarness.create(
89132
root=tmp_path,
90133
app_source=MediaApp,
@@ -116,52 +159,44 @@ def test_media_app(media_app: AppHarness):
116159
gif_img = driver.find_element(By.ID, "gif")
117160
webp_img = driver.find_element(By.ID, "webp")
118161
from_url_img = driver.find_element(By.ID, "from_url")
119-
120-
def check_image_loaded(img, check_width=" == 200", check_height=" == 200"):
121-
return driver.execute_script(
122-
"console.log(arguments); return arguments[1].complete "
123-
'&& typeof arguments[1].naturalWidth != "undefined" '
124-
f"&& arguments[1].naturalWidth {check_width} ",
125-
'&& typeof arguments[1].naturalHeight != "undefined" '
126-
f"&& arguments[1].naturalHeight {check_height} ",
127-
img,
128-
)
162+
uploaded_img = driver.find_element(By.ID, "uploaded")
129163

130164
default_img_src = default_img.get_attribute("src")
131165
assert default_img_src is not None
132166
assert default_img_src.startswith("data:image/png;base64")
133-
assert check_image_loaded(default_img)
167+
assert check_image_loaded(driver, default_img)
134168

135169
bmp_img_src = bmp_img.get_attribute("src")
136170
assert bmp_img_src is not None
137171
assert bmp_img_src.startswith("data:image/bmp;base64")
138-
assert check_image_loaded(bmp_img)
172+
assert check_image_loaded(driver, bmp_img)
139173

140174
jpg_img_src = jpg_img.get_attribute("src")
141175
assert jpg_img_src is not None
142176
assert jpg_img_src.startswith("data:image/jpeg;base64")
143-
assert check_image_loaded(jpg_img)
177+
assert check_image_loaded(driver, jpg_img)
144178

145179
png_img_src = png_img.get_attribute("src")
146180
assert png_img_src is not None
147181
assert png_img_src.startswith("data:image/png;base64")
148-
assert check_image_loaded(png_img)
182+
assert check_image_loaded(driver, png_img)
149183

150184
gif_img_src = gif_img.get_attribute("src")
151185
assert gif_img_src is not None
152186
assert gif_img_src.startswith("data:image/gif;base64")
153-
assert check_image_loaded(gif_img)
187+
assert check_image_loaded(driver, gif_img)
154188

155189
webp_img_src = webp_img.get_attribute("src")
156190
assert webp_img_src is not None
157191
assert webp_img_src.startswith("data:image/webp;base64")
158-
assert check_image_loaded(webp_img)
192+
assert check_image_loaded(driver, webp_img)
159193

160194
from_url_img_src = from_url_img.get_attribute("src")
161195
assert from_url_img_src is not None
162196
assert from_url_img_src.startswith("data:image/jpeg;base64")
163-
assert check_image_loaded(
164-
from_url_img,
165-
check_width=" == 200",
166-
check_height=" == 300",
167-
)
197+
assert check_image_loaded(driver, from_url_img, expected_height=300)
198+
199+
uploaded_img_src = uploaded_img.get_attribute("src")
200+
assert uploaded_img_src is not None
201+
assert "generated.png" in uploaded_img_src
202+
assert check_image_loaded(driver, uploaded_img, 150, 150)

0 commit comments

Comments
 (0)