Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions reflex/components/core/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@ def create(cls, *children, **props) -> Component:
elif isinstance(on_drop, Callable):
# Call the lambda to get the event chain.
on_drop = call_event_fn(on_drop, _on_drop_spec)
if isinstance(on_drop, EventSpec):
# Update the provided args for direct use with on_drop.
on_drop = on_drop.with_args(
args=tuple(
cls._update_arg_tuple_for_on_drop(arg_value)
for arg_value in on_drop.args
),
)
upload_props["on_drop"] = on_drop

input_props_unique_name = get_unique_variable_name()
Expand Down
90 changes: 81 additions & 9 deletions tests/integration/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,32 @@ def UploadFile():
class UploadState(rx.State):
_file_data: dict[str, str] = {}
event_order: rx.Field[list[str]] = rx.field([])
progress_dicts: list[dict] = []
disabled: bool = False
large_data: str = ""
progress_dicts: rx.Field[list[dict]] = rx.field([])
disabled: rx.Field[bool] = rx.field(False)
large_data: rx.Field[str] = rx.field("")
quaternary_names: rx.Field[list[str]] = rx.field([])

@rx.event
async def handle_upload(self, files: list[rx.UploadFile]):
for file in files:
upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8")
self._file_data[file.name or ""] = upload_data.decode("utf-8")

@rx.event
async def handle_upload_secondary(self, files: list[rx.UploadFile]):
for file in files:
upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8")
self._file_data[file.name or ""] = upload_data.decode("utf-8")
self.large_data = LARGE_DATA
yield UploadState.chain_event

@rx.event
def upload_progress(self, progress):
assert progress
self.event_order.append("upload_progress")
self.progress_dicts.append(progress)

@rx.event
def chain_event(self):
assert self.large_data == LARGE_DATA
self.large_data = ""
Expand All @@ -55,10 +60,14 @@ def chain_event(self):
@rx.event
async def handle_upload_tertiary(self, files: list[rx.UploadFile]):
for file in files:
(rx.get_upload_dir() / (file.filename or "INVALID")).write_bytes(
(rx.get_upload_dir() / (file.name or "INVALID")).write_bytes(
await file.read()
)

@rx.event
async def handle_upload_quaternary(self, files: list[rx.UploadFile]):
self.quaternary_names = [file.name for file in files if file.name]

@rx.event
def do_download(self):
return rx.download(rx.get_upload_url("test.txt"))
Expand All @@ -80,7 +89,7 @@ def index():
),
rx.button(
"Upload",
on_click=lambda: UploadState.handle_upload(rx.upload_files()), # pyright: ignore [reportCallIssue]
on_click=lambda: UploadState.handle_upload(rx.upload_files()), # pyright: ignore [reportArgumentType]
id="upload_button",
),
rx.box(
Expand All @@ -105,8 +114,8 @@ def index():
),
rx.button(
"Upload",
on_click=UploadState.handle_upload_secondary( # pyright: ignore [reportCallIssue]
rx.upload_files(
on_click=UploadState.handle_upload_secondary(
rx.upload_files( # pyright: ignore [reportArgumentType]
upload_id="secondary",
on_upload_progress=UploadState.upload_progress,
),
Expand Down Expand Up @@ -163,6 +172,22 @@ def index():
on_click=UploadState.do_download,
id="download-backend",
),
rx.upload.root(
rx.vstack(
rx.button("Select File"),
rx.text("Drag and drop files here or click to select files"),
),
on_drop=UploadState.handle_upload_quaternary(
rx.upload_files( # pyright: ignore [reportArgumentType]
upload_id="quaternary",
),
),
id="quaternary",
),
rx.text(
UploadState.quaternary_names.to_string(),
id="quaternary_files",
),
rx.text(UploadState.event_order.to_string(), id="event-order"),
)

Expand Down Expand Up @@ -501,3 +526,50 @@ async def test_upload_download_file(
download_backend.click()
assert urlsplit(driver.current_url).path == f"/{Endpoint.UPLOAD.value}/test.txt"
assert driver.find_element(by=By.TAG_NAME, value="body").text == exp_contents


@pytest.mark.asyncio
async def test_on_drop(
tmp_path,
upload_file: AppHarness,
driver: WebDriver,
):
"""Test the on_drop event handler.

Args:
tmp_path: pytest tmp_path fixture
upload_file: harness for UploadFile app.
driver: WebDriver instance.
"""
assert upload_file.app_instance is not None
token = poll_for_token(driver, upload_file)
full_state_name = upload_file.get_full_state_name(["_upload_state"])
state_name = upload_file.get_state_name("_upload_state")
substate_token = f"{token}_{full_state_name}"

upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
3
] # quaternary upload
assert upload_box

exp_name = "drop_test.txt"
exp_contents = "dropped file contents!"
target_file = tmp_path / exp_name
target_file.write_text(exp_contents)

# Simulate file drop by directly setting the file input
upload_box.send_keys(str(target_file))

# Wait for the on_drop event to be processed
await asyncio.sleep(0.5)

async def exp_name_in_quaternary():
state = await upload_file.get_state(substate_token)
return exp_name in state.substates[state_name].quaternary_names

# Poll until the file names appear in the display
await AppHarness._poll_for_async(exp_name_in_quaternary)

# Verify through state that the file names were captured correctly
state = await upload_file.get_state(substate_token)
assert exp_name in state.substates[state_name].quaternary_names
18 changes: 9 additions & 9 deletions tests/units/states/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]):
"""
for file in files:
upload_data = await file.read()
assert file.filename is not None
outfile = self._tmp_path / file.filename
assert file.name is not None
outfile = self._tmp_path / file.name

# Save the file.
outfile.write_bytes(upload_data)

# Update the img var.
self.img_list.append(file.filename)
self.img_list.append(file.name)

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
Expand Down Expand Up @@ -106,14 +106,14 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]):
"""
for file in files:
upload_data = await file.read()
assert file.filename is not None
outfile = self._tmp_path / file.filename
assert file.name is not None
outfile = self._tmp_path / file.name

# Save the file.
outfile.write_bytes(upload_data)

# Update the img var.
self.img_list.append(file.filename)
self.img_list.append(file.name)

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
Expand Down Expand Up @@ -153,14 +153,14 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]):
"""
for file in files:
upload_data = await file.read()
assert file.filename is not None
outfile = self._tmp_path / file.filename
assert file.name is not None
outfile = self._tmp_path / file.name

# Save the file.
outfile.write_bytes(upload_data)

# Update the img var.
self.img_list.append(file.filename)
self.img_list.append(file.name)

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
Expand Down
Loading