Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/reflex-base/src/reflex_base/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
"RouteArgType",
"RouteRegex",
"RouteVar",
"RunningMode",
"SocketEvent",
"StateManagerMode",
"Templates",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import contextlib
import dataclasses
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncGenerator, AsyncIterator, MutableMapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, BinaryIO, cast

Expand All @@ -23,6 +23,8 @@
from typing_extensions import Self

if TYPE_CHECKING:
from reflex_base.utils.types import ASGIApp, Receive, Scope, Send

from reflex.app import App


Expand Down Expand Up @@ -575,6 +577,62 @@ async def _upload_chunk_file(
return Response(status_code=202)


header_content_disposition = b"content-disposition"
header_content_type = b"content-type"
header_x_content_type_options = b"x-content-type-options"


class UploadedFilesHeadersMiddleware:
"""ASGI middleware that adds security headers to uploaded file responses."""

def __init__(self, app: ASGIApp) -> None:
"""Wrap an ASGI application with upload security headers.

Args:
app: The ASGI application to wrap.
"""
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Add Content-Disposition and X-Content-Type-Options headers.

Args:
scope: The ASGI scope.
receive: The ASGI receive callable.
send: The ASGI send callable.
"""
if scope["type"] != "http":
await self.app(scope, receive, send)
return

async def send_with_headers(message: MutableMapping[str, Any]) -> None:
if message["type"] == "http.response.start":
content_disposition = None
content_type = None
headers = [(header_x_content_type_options, b"nosniff")]
for header_name, header_value in message.get("headers", []):
lower_name = header_name.lower()
if lower_name == header_content_disposition:
content_disposition = header_value.lower()
# Always append content-disposition header if non-empty.
continue
if lower_name == header_x_content_type_options:
# Always replace this value with "nosniff", so ignore existing value.
continue
if lower_name == header_content_type:
content_type = header_value.lower()
headers.append((header_name, header_value))
if content_type != b"application/pdf":
# Unknown content or non-PDF forces download.
content_disposition = b"attachment"
if content_disposition:
headers.append((header_content_disposition, content_disposition))
message = {**message, "headers": headers}
await send(message)

await self.app(scope, receive, send_with_headers)


def upload(app: App):
"""Upload files, dispatching to buffered or streaming handling.

Expand Down
4 changes: 2 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@
from starlette.staticfiles import StaticFiles
from typing_extensions import Unpack

from reflex._upload import UploadedFilesHeadersMiddleware, upload
from reflex._upload import UploadFile as UploadFile
from reflex._upload import upload
from reflex.admin import AdminDash
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
from reflex.compiler import compiler
Expand Down Expand Up @@ -714,7 +714,7 @@ def _add_optional_endpoints(self):
# To access uploaded files.
self._api.mount(
str(constants.Endpoint.UPLOAD),
StaticFiles(directory=get_upload_dir()),
UploadedFilesHeadersMiddleware(StaticFiles(directory=get_upload_dir())),
name="uploaded_files",
)

Expand Down
79 changes: 57 additions & 22 deletions tests/integration/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def img_from_url(self) -> Image.Image:
img_bytes = img_resp.content
return Image.open(io.BytesIO(img_bytes))

@rx.var
def generated_image(self) -> str:
# Generate a 150x150 red PNG and write it to the upload directory.
img = Image.new("RGB", (150, 150), "red")
upload_dir = rx.get_upload_dir()
upload_dir.mkdir(parents=True, exist_ok=True)
img.save(upload_dir / "generated.png")
return "generated.png"

app = rx.App()

@app.add_page
Expand All @@ -72,19 +81,53 @@ def index():
rx.image(src=State.img_gif, alt="GIF image", id="gif"),
rx.image(src=State.img_webp, alt="WEBP image", id="webp"),
rx.image(src=State.img_from_url, alt="Image from URL", id="from_url"),
rx.image(
src=rx.get_upload_url(State.generated_image),
alt="Uploaded image",
id="uploaded",
),
)


def check_image_loaded(
driver, img, expected_width: int = 200, expected_height: int = 200
) -> bool:
"""Check whether an image element has fully loaded with expected dimensions.

Args:
driver: WebDriver instance.
img: The image WebElement.
expected_width: Expected natural width.
expected_height: Expected natural height.

Returns:
True if the image is complete and matches the expected dimensions.
"""
return driver.execute_script(
"return arguments[0].complete "
'&& typeof arguments[0].naturalWidth != "undefined" '
"&& arguments[0].naturalWidth === arguments[1] "
'&& typeof arguments[0].naturalHeight != "undefined" '
"&& arguments[0].naturalHeight === arguments[2]",
img,
expected_width,
expected_height,
)


@pytest.fixture
def media_app(tmp_path) -> Generator[AppHarness, None, None]:
def media_app(tmp_path, monkeypatch) -> Generator[AppHarness, None, None]:
"""Start MediaApp app at tmp_path via AppHarness.

Args:
tmp_path: pytest tmp_path fixture
monkeypatch: pytest monkeypatch fixture

Yields:
running AppHarness instance
"""
monkeypatch.setenv("REFLEX_UPLOADED_FILES_DIR", str(tmp_path / "uploads"))

with AppHarness.create(
root=tmp_path,
app_source=MediaApp,
Expand Down Expand Up @@ -116,52 +159,44 @@ def test_media_app(media_app: AppHarness):
gif_img = driver.find_element(By.ID, "gif")
webp_img = driver.find_element(By.ID, "webp")
from_url_img = driver.find_element(By.ID, "from_url")

def check_image_loaded(img, check_width=" == 200", check_height=" == 200"):
return driver.execute_script(
"console.log(arguments); return arguments[1].complete "
'&& typeof arguments[1].naturalWidth != "undefined" '
f"&& arguments[1].naturalWidth {check_width} ",
'&& typeof arguments[1].naturalHeight != "undefined" '
f"&& arguments[1].naturalHeight {check_height} ",
img,
)
uploaded_img = driver.find_element(By.ID, "uploaded")

default_img_src = default_img.get_attribute("src")
assert default_img_src is not None
assert default_img_src.startswith("data:image/png;base64")
assert check_image_loaded(default_img)
assert check_image_loaded(driver, default_img)

bmp_img_src = bmp_img.get_attribute("src")
assert bmp_img_src is not None
assert bmp_img_src.startswith("data:image/bmp;base64")
assert check_image_loaded(bmp_img)
assert check_image_loaded(driver, bmp_img)

jpg_img_src = jpg_img.get_attribute("src")
assert jpg_img_src is not None
assert jpg_img_src.startswith("data:image/jpeg;base64")
assert check_image_loaded(jpg_img)
assert check_image_loaded(driver, jpg_img)

png_img_src = png_img.get_attribute("src")
assert png_img_src is not None
assert png_img_src.startswith("data:image/png;base64")
assert check_image_loaded(png_img)
assert check_image_loaded(driver, png_img)

gif_img_src = gif_img.get_attribute("src")
assert gif_img_src is not None
assert gif_img_src.startswith("data:image/gif;base64")
assert check_image_loaded(gif_img)
assert check_image_loaded(driver, gif_img)

webp_img_src = webp_img.get_attribute("src")
assert webp_img_src is not None
assert webp_img_src.startswith("data:image/webp;base64")
assert check_image_loaded(webp_img)
assert check_image_loaded(driver, webp_img)

from_url_img_src = from_url_img.get_attribute("src")
assert from_url_img_src is not None
assert from_url_img_src.startswith("data:image/jpeg;base64")
assert check_image_loaded(
from_url_img,
check_width=" == 200",
check_height=" == 300",
)
assert check_image_loaded(driver, from_url_img, expected_height=300)

uploaded_img_src = uploaded_img.get_attribute("src")
assert uploaded_img_src is not None
assert "generated.png" in uploaded_img_src
assert check_image_loaded(driver, uploaded_img, 150, 150)
Loading
Loading