diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index bc306868f74..e3241c7bccc 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -273,6 +273,14 @@ def create(cls, *children, **props) -> Component: elif isinstance(on_drop, Callable): # Call the lambda to get the event chain. on_drop = call_event_fn(on_drop, _on_drop_spec) + if isinstance(on_drop, EventSpec): + # Update the provided args for direct use with on_drop. + on_drop = on_drop.with_args( + args=tuple( + cls._update_arg_tuple_for_on_drop(arg_value) + for arg_value in on_drop.args + ), + ) upload_props["on_drop"] = on_drop input_props_unique_name = get_unique_variable_name() diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index ca7689ff7c9..d4937af9414 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -26,27 +26,32 @@ def UploadFile(): class UploadState(rx.State): _file_data: dict[str, str] = {} event_order: rx.Field[list[str]] = rx.field([]) - progress_dicts: list[dict] = [] - disabled: bool = False - large_data: str = "" + progress_dicts: rx.Field[list[dict]] = rx.field([]) + disabled: rx.Field[bool] = rx.field(False) + large_data: rx.Field[str] = rx.field("") + quaternary_names: rx.Field[list[str]] = rx.field([]) + @rx.event async def handle_upload(self, files: list[rx.UploadFile]): for file in files: upload_data = await file.read() - self._file_data[file.filename or ""] = upload_data.decode("utf-8") + self._file_data[file.name or ""] = upload_data.decode("utf-8") + @rx.event async def handle_upload_secondary(self, files: list[rx.UploadFile]): for file in files: upload_data = await file.read() - self._file_data[file.filename or ""] = upload_data.decode("utf-8") + self._file_data[file.name or ""] = upload_data.decode("utf-8") self.large_data = LARGE_DATA yield UploadState.chain_event + @rx.event def upload_progress(self, progress): assert progress self.event_order.append("upload_progress") self.progress_dicts.append(progress) + @rx.event def chain_event(self): assert self.large_data == LARGE_DATA self.large_data = "" @@ -55,10 +60,14 @@ def chain_event(self): @rx.event async def handle_upload_tertiary(self, files: list[rx.UploadFile]): for file in files: - (rx.get_upload_dir() / (file.filename or "INVALID")).write_bytes( + (rx.get_upload_dir() / (file.name or "INVALID")).write_bytes( await file.read() ) + @rx.event + async def handle_upload_quaternary(self, files: list[rx.UploadFile]): + self.quaternary_names = [file.name for file in files if file.name] + @rx.event def do_download(self): return rx.download(rx.get_upload_url("test.txt")) @@ -80,7 +89,7 @@ def index(): ), rx.button( "Upload", - on_click=lambda: UploadState.handle_upload(rx.upload_files()), # pyright: ignore [reportCallIssue] + on_click=lambda: UploadState.handle_upload(rx.upload_files()), # pyright: ignore [reportArgumentType] id="upload_button", ), rx.box( @@ -105,8 +114,8 @@ def index(): ), rx.button( "Upload", - on_click=UploadState.handle_upload_secondary( # pyright: ignore [reportCallIssue] - rx.upload_files( + on_click=UploadState.handle_upload_secondary( + rx.upload_files( # pyright: ignore [reportArgumentType] upload_id="secondary", on_upload_progress=UploadState.upload_progress, ), @@ -163,6 +172,22 @@ def index(): on_click=UploadState.do_download, id="download-backend", ), + rx.upload.root( + rx.vstack( + rx.button("Select File"), + rx.text("Drag and drop files here or click to select files"), + ), + on_drop=UploadState.handle_upload_quaternary( + rx.upload_files( # pyright: ignore [reportArgumentType] + upload_id="quaternary", + ), + ), + id="quaternary", + ), + rx.text( + UploadState.quaternary_names.to_string(), + id="quaternary_files", + ), rx.text(UploadState.event_order.to_string(), id="event-order"), ) @@ -501,3 +526,50 @@ async def test_upload_download_file( download_backend.click() assert urlsplit(driver.current_url).path == f"/{Endpoint.UPLOAD.value}/test.txt" assert driver.find_element(by=By.TAG_NAME, value="body").text == exp_contents + + +@pytest.mark.asyncio +async def test_on_drop( + tmp_path, + upload_file: AppHarness, + driver: WebDriver, +): + """Test the on_drop event handler. + + Args: + tmp_path: pytest tmp_path fixture + upload_file: harness for UploadFile app. + driver: WebDriver instance. + """ + assert upload_file.app_instance is not None + token = poll_for_token(driver, upload_file) + full_state_name = upload_file.get_full_state_name(["_upload_state"]) + state_name = upload_file.get_state_name("_upload_state") + substate_token = f"{token}_{full_state_name}" + + upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[ + 3 + ] # quaternary upload + assert upload_box + + exp_name = "drop_test.txt" + exp_contents = "dropped file contents!" + target_file = tmp_path / exp_name + target_file.write_text(exp_contents) + + # Simulate file drop by directly setting the file input + upload_box.send_keys(str(target_file)) + + # Wait for the on_drop event to be processed + await asyncio.sleep(0.5) + + async def exp_name_in_quaternary(): + state = await upload_file.get_state(substate_token) + return exp_name in state.substates[state_name].quaternary_names + + # Poll until the file names appear in the display + await AppHarness._poll_for_async(exp_name_in_quaternary) + + # Verify through state that the file names were captured correctly + state = await upload_file.get_state(substate_token) + assert exp_name in state.substates[state_name].quaternary_names diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index 4e93cd35279..6942a430b4b 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -59,14 +59,14 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]): """ for file in files: upload_data = await file.read() - assert file.filename is not None - outfile = self._tmp_path / file.filename + assert file.name is not None + outfile = self._tmp_path / file.name # Save the file. outfile.write_bytes(upload_data) # Update the img var. - self.img_list.append(file.filename) + self.img_list.append(file.name) @rx.event(background=True) async def bg_upload(self, files: list[rx.UploadFile]): @@ -106,14 +106,14 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]): """ for file in files: upload_data = await file.read() - assert file.filename is not None - outfile = self._tmp_path / file.filename + assert file.name is not None + outfile = self._tmp_path / file.name # Save the file. outfile.write_bytes(upload_data) # Update the img var. - self.img_list.append(file.filename) + self.img_list.append(file.name) @rx.event(background=True) async def bg_upload(self, files: list[rx.UploadFile]): @@ -153,14 +153,14 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]): """ for file in files: upload_data = await file.read() - assert file.filename is not None - outfile = self._tmp_path / file.filename + assert file.name is not None + outfile = self._tmp_path / file.name # Save the file. outfile.write_bytes(upload_data) # Update the img var. - self.img_list.append(file.filename) + self.img_list.append(file.name) @rx.event(background=True) async def bg_upload(self, files: list[rx.UploadFile]):