diff --git a/pyi_hashes.json b/pyi_hashes.json index 8f998b4650f..8a98f576ff8 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -42,7 +42,7 @@ "reflex/components/lucide/icon.pyi": "775e6686e491fd46f28a00b19699db3d", "reflex/components/markdown/markdown.pyi": "73d3116fa28450c90f25b21107285daa", "reflex/components/moment/moment.pyi": "ab1d6618159693014fdf22b4aa84c877", - "reflex/components/next/base.pyi": "5ea32ecae5c64e02217c8895783f9ccb", + "reflex/components/next/base.pyi": "5e75245c2b0ee4715f89efaf42d101d8", "reflex/components/next/image.pyi": "8c305c03019d37c07560c154a05bf5dd", "reflex/components/next/link.pyi": "cc438e48a9f31bf16f1cdb6e16017477", "reflex/components/next/video.pyi": "8f5694a4a2118c5297e2eba479b6f018", @@ -52,68 +52,68 @@ "reflex/components/radix/primitives/accordion.pyi": "a31599f0b2a1a69a10917137dcb75a9d", "reflex/components/radix/primitives/base.pyi": "fc910c9bd364b57e1c092fbf8889158d", "reflex/components/radix/primitives/drawer.pyi": "8f20bac0e36266398be1a124218bda87", - "reflex/components/radix/primitives/form.pyi": "11402dfac6256f2220c5c830008b8b8b", + "reflex/components/radix/primitives/form.pyi": "efd2ec67535eb1b1eefaafc0d5e36d8a", "reflex/components/radix/primitives/progress.pyi": "98b4add410a80a353ab503ad577169c2", "reflex/components/radix/primitives/slider.pyi": "573837a7d8d90deaf57c911faffed254", "reflex/components/radix/themes/__init__.pyi": "a15f9464ad99f248249ffa8e6deea4cf", "reflex/components/radix/themes/base.pyi": "526db93a3f52bb00ad220f8744eba797", "reflex/components/radix/themes/color_mode.pyi": "f7515dccd1e315dc28a3cbbe2eabe7ff", "reflex/components/radix/themes/components/__init__.pyi": "87bb9ffff641928562da1622d2ca5993", - "reflex/components/radix/themes/components/alert_dialog.pyi": "9f19bcdb4588a7f76596d142a0ac0950", - "reflex/components/radix/themes/components/aspect_ratio.pyi": "ecace271fa2c518c429594556ddf4389", - "reflex/components/radix/themes/components/avatar.pyi": "51d3f65fb3e5c4abda00cc8bf4a7e50c", - "reflex/components/radix/themes/components/badge.pyi": "1ecf1253abb3a7e293146d4cc6327ceb", - "reflex/components/radix/themes/components/button.pyi": "70b5258eb4c2716af39f1b2e5bfc4cbb", - "reflex/components/radix/themes/components/callout.pyi": "aa9d08f1246d9c7f97ad6a3ac4d5fcb5", - "reflex/components/radix/themes/components/card.pyi": "60374dee8093535874fac2901d993aaf", - "reflex/components/radix/themes/components/checkbox.pyi": "0766d08ef379dd919134ff22481528c6", - "reflex/components/radix/themes/components/checkbox_cards.pyi": "7cb7297d3e3388efbd2b678278bb034b", - "reflex/components/radix/themes/components/checkbox_group.pyi": "0878853ed682b3930fbf0c4f0a655ba2", - "reflex/components/radix/themes/components/context_menu.pyi": "4f64ded6e04727c9d24ef2518f9db540", - "reflex/components/radix/themes/components/data_list.pyi": "a07a9e89e0fb3f10db78549029fecb37", - "reflex/components/radix/themes/components/dialog.pyi": "8b9725b561c253b37562279ce94a99e9", - "reflex/components/radix/themes/components/dropdown_menu.pyi": "1a0bdafb4fa95044c8edcc9e83efacf5", - "reflex/components/radix/themes/components/hover_card.pyi": "f15aedcd77ce8a7ab7f7470780fe4035", - "reflex/components/radix/themes/components/icon_button.pyi": "3887d4225f5ead440e8aeecceec990fd", - "reflex/components/radix/themes/components/inset.pyi": "3dbda9fbe5f660c8bfda717aceb0dbdc", - "reflex/components/radix/themes/components/popover.pyi": "bf2cd9e744a23305b74ff888d980993f", - "reflex/components/radix/themes/components/progress.pyi": "a5610ee8a8eab36b1aada37e866f9494", - "reflex/components/radix/themes/components/radio.pyi": "69f5c47aee9a1179c273a4e4765c6099", - "reflex/components/radix/themes/components/radio_cards.pyi": "6f323c60aff4da0f576655c32d208bb8", - "reflex/components/radix/themes/components/radio_group.pyi": "4d9d918832555a5fa3efa4a71df15ad2", - "reflex/components/radix/themes/components/scroll_area.pyi": "7b507e661c87b08061df4e13e73ab47b", - "reflex/components/radix/themes/components/segmented_control.pyi": "a848ceda014c4f64a1adc89202598c15", - "reflex/components/radix/themes/components/select.pyi": "b223797edc8b9d3341c105c796d392de", - "reflex/components/radix/themes/components/separator.pyi": "92c789575a1336bb3e5dcd2012fb68a1", - "reflex/components/radix/themes/components/skeleton.pyi": "34340e43123c2aaa89f042411cae06ec", - "reflex/components/radix/themes/components/slider.pyi": "023bc8fada28779c0d2f8f14f8b30fec", - "reflex/components/radix/themes/components/spinner.pyi": "941dfcee9581f116af7c7116084a6938", - "reflex/components/radix/themes/components/switch.pyi": "9e3dfd7dfa16166bb2adc5fb60b25438", - "reflex/components/radix/themes/components/table.pyi": "5643313daebc43bc6246d0beee81505f", - "reflex/components/radix/themes/components/tabs.pyi": "0bc64cfc23592767477af649339f0e4e", - "reflex/components/radix/themes/components/text_area.pyi": "8d976ea7e23b0f5942aeb3a0d295835c", - "reflex/components/radix/themes/components/text_field.pyi": "77b28c7caebea3fdb4ba5ef6f3ce19f3", - "reflex/components/radix/themes/components/tooltip.pyi": "050ecd7a591e358170d332b1e9c07059", + "reflex/components/radix/themes/components/alert_dialog.pyi": "3832f3e8a6a3eed1bfa969efea627b72", + "reflex/components/radix/themes/components/aspect_ratio.pyi": "f90aa46ef8b29bd076d98321de96315a", + "reflex/components/radix/themes/components/avatar.pyi": "d40e8e25a9c007f2554590abd116a095", + "reflex/components/radix/themes/components/badge.pyi": "422c4d1586e6b22d00a2d5f002989651", + "reflex/components/radix/themes/components/button.pyi": "fc5c290d6df9b5197c65036c9edafa38", + "reflex/components/radix/themes/components/callout.pyi": "f81f5032d90e36705fd5b5ba30d3f3ab", + "reflex/components/radix/themes/components/card.pyi": "fdf71624bdeeba391d1c0545039dd2e7", + "reflex/components/radix/themes/components/checkbox.pyi": "006614845a236f6611c656c05a8db394", + "reflex/components/radix/themes/components/checkbox_cards.pyi": "289d0fd448f654e17f3132234e9d4983", + "reflex/components/radix/themes/components/checkbox_group.pyi": "8eb8cca3e0c5885150576cc60ba19fea", + "reflex/components/radix/themes/components/context_menu.pyi": "5f178adef09c0f36103e33b11326cb2a", + "reflex/components/radix/themes/components/data_list.pyi": "4014ea23eec39cfe98032852665a92ca", + "reflex/components/radix/themes/components/dialog.pyi": "8dc1a09d30aff2fcf28e28988fca170f", + "reflex/components/radix/themes/components/dropdown_menu.pyi": "0cd87cddbe9a83dcfa9cbcc4f9d98dd4", + "reflex/components/radix/themes/components/hover_card.pyi": "973a4911f68cec60f40a8e2ca5e42770", + "reflex/components/radix/themes/components/icon_button.pyi": "55e0b8c8233d1e5a52a8c09c959e4989", + "reflex/components/radix/themes/components/inset.pyi": "a215de3b29b2133626cbfc83544305fe", + "reflex/components/radix/themes/components/popover.pyi": "1fa6f96aef6f148f110fa208aa449ab7", + "reflex/components/radix/themes/components/progress.pyi": "1d0f827e8db089418b2786f82d55512d", + "reflex/components/radix/themes/components/radio.pyi": "a8fcc63bf42129196d70a9571647d4bb", + "reflex/components/radix/themes/components/radio_cards.pyi": "b5222b86e418920de2ef988752f0b577", + "reflex/components/radix/themes/components/radio_group.pyi": "7b95ee1fcd41186f2c7670be273d134f", + "reflex/components/radix/themes/components/scroll_area.pyi": "28352b03135ef2065876a5199b9c150a", + "reflex/components/radix/themes/components/segmented_control.pyi": "0477ee74033ed0f67cd2cb94a47ccea9", + "reflex/components/radix/themes/components/select.pyi": "9c63eb11bab2d2913431ec0c13111b6d", + "reflex/components/radix/themes/components/separator.pyi": "f8c9c18ea7f67e8287f4ebc5c09790b5", + "reflex/components/radix/themes/components/skeleton.pyi": "aeff3cbc53989c4824a5e49e9ea3bbca", + "reflex/components/radix/themes/components/slider.pyi": "242e107d73ec14d984cb88fa8f23ad68", + "reflex/components/radix/themes/components/spinner.pyi": "5050ba710b0c950c29f69cafd93f6c4f", + "reflex/components/radix/themes/components/switch.pyi": "61729a28148bc17acd20e48c12f60a54", + "reflex/components/radix/themes/components/table.pyi": "81c77cecf78ddb3e931c9a5f0f8eccde", + "reflex/components/radix/themes/components/tabs.pyi": "6facf7ebd344f8995934a167af01a9e5", + "reflex/components/radix/themes/components/text_area.pyi": "eef90fcc66990c44f3c0540862877cba", + "reflex/components/radix/themes/components/text_field.pyi": "92552297cc747dd3aae6f382699e319d", + "reflex/components/radix/themes/components/tooltip.pyi": "5e17b67e50410f1124d2150237eab7cf", "reflex/components/radix/themes/layout/__init__.pyi": "9a52c5b283c864be70b51a8fd6120392", - "reflex/components/radix/themes/layout/base.pyi": "a3a869acd2a1c5025580697ae5e2c024", - "reflex/components/radix/themes/layout/box.pyi": "d2d2b266eed53e866c5b5ad8cee292e4", + "reflex/components/radix/themes/layout/base.pyi": "6a255a392bf0d54c924c26e673248971", + "reflex/components/radix/themes/layout/box.pyi": "731cc26fc41d2b174ed4e901f5292479", "reflex/components/radix/themes/layout/center.pyi": "e0592f33bdec5586a7377ca986f1a966", - "reflex/components/radix/themes/layout/container.pyi": "691ec3a849be5f42c0b5d6ba1b243b55", - "reflex/components/radix/themes/layout/flex.pyi": "ed2746b5cd2b3d9ef73e370f85a66043", - "reflex/components/radix/themes/layout/grid.pyi": "6543e4413501fd41a20ff4d58931b584", + "reflex/components/radix/themes/layout/container.pyi": "3c5ddf03873da9bf0f5308d5d6429097", + "reflex/components/radix/themes/layout/flex.pyi": "8d8cfd4f00e21aac8d165ded0f7c600f", + "reflex/components/radix/themes/layout/grid.pyi": "412f164266f810671cf38ca5e50d9cfd", "reflex/components/radix/themes/layout/list.pyi": "32ce23a3f851698ac0d609e616bd3605", - "reflex/components/radix/themes/layout/section.pyi": "2b9b826ab42eae3f8cf4d1899dea4b33", + "reflex/components/radix/themes/layout/section.pyi": "2904116ccc24dcb66285ff2daaac1875", "reflex/components/radix/themes/layout/spacer.pyi": "3def4df36e8eecdfba0a7d2f1890b908", - "reflex/components/radix/themes/layout/stack.pyi": "1b09d9123358d430ad6c66343d0e9c92", + "reflex/components/radix/themes/layout/stack.pyi": "b7ec458d254cd09058ca805d553199da", "reflex/components/radix/themes/typography/__init__.pyi": "ef0ba71353dcac1f3546de45f8721bae", - "reflex/components/radix/themes/typography/blockquote.pyi": "04de9fdb22583d87faaba5619bdc6e3e", - "reflex/components/radix/themes/typography/code.pyi": "bd58d40878c3488f1ba58a122e78f4e7", - "reflex/components/radix/themes/typography/heading.pyi": "91bfc9176f7e9ef33d1f69711ceddbe1", - "reflex/components/radix/themes/typography/link.pyi": "febffdd31eee7a4f67d12d6e10a13516", - "reflex/components/radix/themes/typography/text.pyi": "d2ba2f718acd0eaf7b5923fe6a27d59c", - "reflex/components/react_player/audio.pyi": "bd7e024d39ac641f8279ee0f6afd7985", + "reflex/components/radix/themes/typography/blockquote.pyi": "fdd2214a8416bcd4ba644a0bd0015c5a", + "reflex/components/radix/themes/typography/code.pyi": "2e0b487ed1128422bfc4105928dbb18a", + "reflex/components/radix/themes/typography/heading.pyi": "bec5af8f72e3c0a764d77e16608da4a1", + "reflex/components/radix/themes/typography/link.pyi": "196d6ef6c1a15f2d7180a973e8753ea5", + "reflex/components/radix/themes/typography/text.pyi": "33f91de2a0ae94e5802e7c8f0971b1df", + "reflex/components/react_player/audio.pyi": "231e9338b19330a6963928f7e90cb40f", "reflex/components/react_player/react_player.pyi": "40db798bcb7fa40207d24f49722135ae", - "reflex/components/react_player/video.pyi": "22d84a7f57be13ece90cb30536d76c7d", + "reflex/components/react_player/video.pyi": "f92885d49cdc565b95b20820d09e2ca2", "reflex/components/recharts/__init__.pyi": "a060a4abcd018165bc499173e723cf9e", "reflex/components/recharts/cartesian.pyi": "601e1acb0ad6bd93ce371d763220aabe", "reflex/components/recharts/charts.pyi": "2f0a39f9c02de83d9e2d97763b4411af", diff --git a/pyproject.toml b/pyproject.toml index b712aa49612..509ed60439f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,50 +91,51 @@ lint.ignore = [ "ANN2", "ANN4", "ARG", - "ASYNC", - "B008", "BLE", "C901", "COM", "D205", "DTZ", "E501", - "EM", "F403", "FBT", "FIX", - "FLY", "G004", - "INP", "ISC003", - "NPY", - "PD", - "PIE", "PLC", "PLR", "PLW", - "PT", + "PT011", + "PT012", "PYI", - "RET", - "RSE", - "RUF006", - "RUF008", "RUF012", "S", - "SIM115", "SLF", "SLOT", "TC", "TD", - "TID", "TRY0", "UP038", ] lint.pydocstyle.convention = "google" +lint.flake8-bugbear.extend-immutable-calls = [ + "reflex.utils.types.Unset", + "reflex.vars.base.Var.create", +] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] -"tests/*.py" = ["ANN001", "D100", "D103", "D104", "B018", "PERF", "T", "N"] +"tests/*.py" = [ + "ANN001", + "D100", + "D103", + "D104", + "INP", + "B018", + "PERF", + "T", + "N", +] "benchmarks/*.py" = ["ANN001", "D100", "D103", "D104", "B018", "PERF", "T", "N"] "reflex/.templates/*.py" = ["D100", "D103", "D104"] "*.pyi" = ["D301", "D415", "D417", "D418", "E742", "N", "PGH"] diff --git a/reflex/.templates/apps/blank/code/blank.py b/reflex/.templates/apps/blank/code/blank.py index c0de6e44fe1..08113eeffe7 100644 --- a/reflex/.templates/apps/blank/code/blank.py +++ b/reflex/.templates/apps/blank/code/blank.py @@ -8,8 +8,6 @@ class State(rx.State): """The app state.""" - ... - def index() -> rx.Component: # Welcome Page (Index) diff --git a/reflex/app.py b/reflex/app.py index 5bcc4f9f6b8..bee5e864ee9 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -229,8 +229,6 @@ def default_error_boundary(*children: Component) -> Component: class OverlayFragment(Fragment): """Alias for Fragment, used to wrap the overlay_component.""" - pass - @dataclasses.dataclass(frozen=True) class UploadFile(StarletteUploadFile): @@ -262,6 +260,7 @@ def name(self) -> str | None: """ if self.path: return self.path.name + return None @property def filename(self) -> str | None: @@ -481,9 +480,8 @@ def __post_init__(self): # Special case to allow test cases have multiple subclasses of rx.BaseState. if not is_testing_env() and BaseState.__subclasses__() != [State]: # Only rx.State is allowed as Base State subclass. - raise ValueError( - "rx.BaseState cannot be subclassed directly. Use rx.State instead" - ) + msg = "rx.BaseState cannot be subclassed directly. Use rx.State instead" + raise ValueError(msg) get_config(reload=True) @@ -552,9 +550,8 @@ def _setup_state(self) -> None: transports=["websocket"], ) elif getattr(self.sio, "async_mode", "") != "asgi": - raise RuntimeError( - f"Custom `sio` must use `async_mode='asgi'`, not '{self.sio.async_mode}'." - ) + msg = f"Custom `sio` must use `async_mode='asgi'`, not '{self.sio.async_mode}'." + raise RuntimeError(msg) # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path. socket_app = EngineIOApp(self.sio, socketio_path="") @@ -633,7 +630,8 @@ def callback(f: concurrent.futures.Future): compile_future.result() if not self._api: - raise ValueError("The app has not been initialized.") + msg = "The app has not been initialized." + raise ValueError(msg) if self._cached_fastapi_app is not None: asgi_app = self._cached_fastapi_app @@ -741,7 +739,8 @@ def state_manager(self) -> StateManager: ValueError: if the state has not been initialized. """ if self._state_manager is None: - raise ValueError("The state manager has not been initialized.") + msg = "The state manager has not been initialized." + raise ValueError(msg) return self._state_manager @staticmethod @@ -791,9 +790,8 @@ def add_page( # If the route is not set, get it from the callable. if route is None: if not isinstance(component, Callable): - raise exceptions.RouteValueError( - "Route must be set if component is not a callable." - ) + msg = "Route must be set if component is not a callable." + raise exceptions.RouteValueError(msg) # Format the route. route = format.format_route(component.__name__) else: @@ -808,9 +806,8 @@ def add_page( image = image or constants.Page404.IMAGE else: if component is None: - raise exceptions.PageValueError( - "Component must be set for a non-404 page." - ) + msg = "Component must be set for a non-404 page." + raise exceptions.PageValueError(msg) # Check if the route given is valid verify_route_validity(route) @@ -841,11 +838,12 @@ def add_page( else f"`{route}`" ) existing_component = self._unevaluated_pages[route].component - raise exceptions.RouteValueError( + msg = ( f"Tried to add page {readable_name_from_component(component)} with route {route_name} but " f"page {readable_name_from_component(existing_component)} with the same route already exists. " "Make sure you do not have two pages with the same route." ) + raise exceptions.RouteValueError(msg) # Setup dynamic args for the route. # this state assignment is only required for tests using the deprecated state kwarg for App @@ -930,10 +928,9 @@ def _check_routes_conflict(self, new_route: str): ): if rw in segments and r != nr: # If the slugs in the segments of both routes are not the same, then the route is invalid - raise RouteValueError( - f"You cannot use different slug names for the same dynamic path in {route} and {new_route} ('{r}' != '{nr}')" - ) - elif rw not in segments and r != nr: + msg = f"You cannot use different slug names for the same dynamic path in {route} and {new_route} ('{r}' != '{nr}')" + raise RouteValueError(msg) + if rw not in segments and r != nr: # if the section being compared in both routes is not a dynamic segment(i.e not wrapped in brackets) # then we are guaranteed that the route is valid and there's no need checking the rest. # eg. /posts/[id]/info/[slug1] and /posts/[id]/info1/[slug1] is always going to be valid since @@ -1088,9 +1085,7 @@ def _add_overlay_to_component(self, component: Component) -> Component: return component # recreate OverlayFragment with overlay_component as first child - component = OverlayFragment.create(overlay_component, *children) - - return component + return OverlayFragment.create(overlay_component, *children) def _setup_overlay_component(self): """If a State is not used and no overlay_component is specified, do not render the connection modal.""" @@ -1152,9 +1147,8 @@ def _validate_var_dependencies(self, state: type[BaseState] | None = None) -> No ) for dep in dep_set: if dep not in state_cls.vars and dep not in state_cls.backend_vars: - raise exceptions.VarDependencyError( - f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}" - ) + msg = f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}" + raise exceptions.VarDependencyError(msg) for substate in state.class_subclasses: self._validate_var_dependencies(substate) @@ -1346,10 +1340,11 @@ def memoized_toast_provider(): # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State. if code_uses_state_contexts(stateful_components_code) and self._state is None: - raise ReflexRuntimeError( + msg = ( "To access rx.State in frontend components, at least one " "subclass of rx.State must be defined in the app." ) + raise ReflexRuntimeError(msg) compile_results.append((stateful_components_path, stateful_components_code)) progress.advance(task) @@ -1543,9 +1538,8 @@ def _submit_work_without_advancing( if path.exists(): file_content = path.read_text() else: - raise FileNotFoundError( - f"Plugin {plugin_name} is trying to modify {path} but it does not exist." - ) + msg = f"Plugin {plugin_name} is trying to modify {path} but it does not exist." + raise FileNotFoundError(msg) output_mapping[path] = modify_fn(file_content) with console.timing("Write to Disk"): @@ -1588,7 +1582,8 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: RuntimeError: If the app has not been initialized yet. """ if self.event_namespace is None: - raise RuntimeError("App has not been initialized yet.") + msg = "App has not been initialized yet." + raise RuntimeError(msg) # Get exclusive access to the state. async with self.state_manager.modify_state(token) as state: @@ -1627,7 +1622,8 @@ async def _coro(): RuntimeError: If the app has not been initialized yet. """ if self.event_namespace is None: - raise RuntimeError("App has not been initialized yet.") + msg = "App has not been initialized yet." + raise RuntimeError(msg) # Process the event. async for update in state._process_event( @@ -1678,20 +1674,17 @@ def _validate_exception_handlers(self): _fn_name = type(handler_fn).__name__ if isinstance(handler_fn, functools.partial): - raise ValueError( - f"Provided custom {handler_domain} exception handler `{_fn_name}` is a partial function. Please provide a named function instead." - ) + msg = f"Provided custom {handler_domain} exception handler `{_fn_name}` is a partial function. Please provide a named function instead." + raise ValueError(msg) if not callable(handler_fn): - raise ValueError( - f"Provided custom {handler_domain} exception handler `{_fn_name}` is not a function." - ) + msg = f"Provided custom {handler_domain} exception handler `{_fn_name}` is not a function." + raise ValueError(msg) # Allow named functions only as lambda functions cannot be introspected if _fn_name == "": - raise ValueError( - f"Provided custom {handler_domain} exception handler `{_fn_name}` is a lambda function. Please use a named function instead." - ) + msg = f"Provided custom {handler_domain} exception handler `{_fn_name}` is a lambda function. Please use a named function instead." + raise ValueError(msg) # Check if the function has the necessary annotations and types in the right order argspec = inspect.getfullargspec(handler_fn) @@ -1703,22 +1696,21 @@ def _validate_exception_handlers(self): for required_arg_index, required_arg in enumerate(handler_spec): if required_arg not in arg_annotations: - raise ValueError( - f"Provided custom {handler_domain} exception handler `{_fn_name}` does not take the required argument `{required_arg}`" - ) - elif ( - not list(arg_annotations.keys())[required_arg_index] == required_arg - ): - raise ValueError( + msg = f"Provided custom {handler_domain} exception handler `{_fn_name}` does not take the required argument `{required_arg}`" + raise ValueError(msg) + if list(arg_annotations.keys())[required_arg_index] != required_arg: + msg = ( f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong argument order." f"Expected `{required_arg}` as the {required_arg_index + 1} argument but got `{list(arg_annotations.keys())[required_arg_index]}`" ) + raise ValueError(msg) if not issubclass(arg_annotations[required_arg], Exception): - raise ValueError( + msg = ( f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong type for {required_arg} argument." f"Expected to be `Exception` but got `{arg_annotations[required_arg]}`" ) + raise ValueError(msg) # Check if the return type is valid for backend exception handler if handler_domain == "backend": @@ -1738,10 +1730,11 @@ def _validate_exception_handlers(self): ) if not valid: - raise ValueError( + msg = ( f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong return type." f"Expected `EventSpec | list[EventSpec] | None` but got `{return_type}`" ) + raise ValueError(msg) async def process( @@ -1910,7 +1903,8 @@ async def upload_file(request: Request): return Response() # user cancelled files = files.getlist("files") if not files: - raise UploadValueError("No files were uploaded.") + msg = "No files were uploaded." + raise UploadValueError(msg) token = request.headers.get("reflex-client-token") handler = request.headers.get("reflex-event-handler") @@ -1937,9 +1931,8 @@ async def upload_file(request: Request): # check if there exists any handler args with annotation, list[UploadFile] if isinstance(func, EventHandler): if func.is_background: - raise UploadTypeError( - f"@rx.event(background=True) is not supported for upload handler `{handler}`.", - ) + msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`." + raise UploadTypeError(msg) func = func.fn if isinstance(func, functools.partial): func = func.func @@ -1952,10 +1945,11 @@ async def upload_file(request: Request): break if not handler_upload_param: - raise UploadValueError( + msg = ( f"`{handler}` handler should have a parameter annotated as " "list[rx.UploadFile]" ) + raise UploadValueError(msg) # Make a copy of the files as they are closed after the request. # This behaviour changed from fastapi 0.103.0 to 0.103.1 as the @@ -2098,32 +2092,31 @@ async def on_event(self, sid: str, data: Any): try: fields = json.loads(fields) except json.JSONDecodeError as ex: - raise exceptions.EventDeserializationError( - f"Failed to deserialize event data: {fields}." - ) from ex + msg = f"Failed to deserialize event data: {fields}." + raise exceptions.EventDeserializationError(msg) from ex if not isinstance(fields, dict): - raise exceptions.EventDeserializationError( - f"Event data must be a dictionary, but received {fields} of type {type(fields)}." - ) + msg = f"Event data must be a dictionary, but received {fields} of type {type(fields)}." + raise exceptions.EventDeserializationError(msg) try: # Get the event. event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) except (TypeError, ValueError) as ex: - raise exceptions.EventDeserializationError( - f"Failed to deserialize event data: {fields}." - ) from ex + msg = f"Failed to deserialize event data: {fields}." + raise exceptions.EventDeserializationError(msg) from ex self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token # Get the event environment. if self.app.sio is None: - raise RuntimeError("Socket.IO is not initialized.") + msg = "Socket.IO is not initialized." + raise RuntimeError(msg) environ = self.app.sio.get_environ(sid, self.namespace) if environ is None: - raise RuntimeError("Socket.IO environ is not initialized.") + msg = "Socket.IO environ is not initialized." + raise RuntimeError(msg) # Get the client headers. headers = { diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index 26ebd934c94..44e3a884280 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -67,9 +67,8 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): InvalidLifespanTaskTypeError: If the task is a generator function. """ if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task): - raise InvalidLifespanTaskTypeError( - f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager." - ) + msg = f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager." + raise InvalidLifespanTaskTypeError(msg) if task_kwargs: original_task = task diff --git a/reflex/app_mixins/middleware.py b/reflex/app_mixins/middleware.py index 0099d8c98e7..b78b96ec2dd 100644 --- a/reflex/app_mixins/middleware.py +++ b/reflex/app_mixins/middleware.py @@ -56,6 +56,7 @@ async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | Non out = await out if out is not None: return out + return None async def _postprocess( self, state: BaseState, event: Event, update: StateUpdate diff --git a/reflex/app_mixins/mixin.py b/reflex/app_mixins/mixin.py index 23207a46292..802cee38c1a 100644 --- a/reflex/app_mixins/mixin.py +++ b/reflex/app_mixins/mixin.py @@ -12,4 +12,3 @@ def _init_mixin(self): Any App mixin can override this method to perform any initialization. """ - ... diff --git a/reflex/assets.py b/reflex/assets.py index cab3485b248..05b00f13290 100644 --- a/reflex/assets.py +++ b/reflex/assets.py @@ -58,9 +58,11 @@ def asset( cwd = Path.cwd() src_file_local = cwd / assets / path if subfolder is not None: - raise ValueError("Subfolder is not supported for local assets.") + msg = "Subfolder is not supported for local assets." + raise ValueError(msg) if not backend_only and not src_file_local.exists(): - raise FileNotFoundError(f"File not found: {src_file_local}") + msg = f"File not found: {src_file_local}" + raise FileNotFoundError(msg) return f"/{path}" # Shared asset handling @@ -73,7 +75,8 @@ def asset( external = constants.Dirs.EXTERNAL_APP_ASSETS src_file_shared = Path(calling_file).parent / path if not src_file_shared.exists(): - raise FileNotFoundError(f"File not found: {src_file_shared}") + msg = f"File not found: {src_file_shared}" + raise FileNotFoundError(msg) caller_module_path = module.__name__.replace(".", "/") subfolder = f"{caller_module_path}/{subfolder}" if subfolder else caller_module_path diff --git a/reflex/base.py b/reflex/base.py index a3425312668..6ad0986c31a 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -30,10 +30,11 @@ def validate_field_name(bases: list[type[BaseModel]], field_name: str) -> None: if not reload and getattr(base, field_name, None): pass except TypeError as te: - raise VarNameError( + msg = ( f'State var "{field_name}" in {base} has been shadowed by a substate var; ' f'use a different field name instead".' - ) from te + ) + raise VarNameError(msg) from te # monkeypatch pydantic validate_field_name method to skip validating diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index fcedd2dfff8..3833e26cb3e 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -201,15 +201,14 @@ def _validate_stylesheet(stylesheet_full_path: Path, assets_app_path: Path) -> N """ suffix = stylesheet_full_path.suffix[1:] if stylesheet_full_path.suffix else "" if suffix not in constants.Reflex.STYLESHEETS_SUPPORTED: - raise ValueError(f"Stylesheet file {stylesheet_full_path} is not supported.") + msg = f"Stylesheet file {stylesheet_full_path} is not supported." + raise ValueError(msg) if not stylesheet_full_path.absolute().is_relative_to(assets_app_path.absolute()): - raise FileNotFoundError( - f"Cannot include stylesheets from outside the assets directory: {stylesheet_full_path}" - ) + msg = f"Cannot include stylesheets from outside the assets directory: {stylesheet_full_path}" + raise FileNotFoundError(msg) if not stylesheet_full_path.name: - raise ValueError( - f"Stylesheet file name cannot be empty: {stylesheet_full_path}" - ) + msg = f"Stylesheet file name cannot be empty: {stylesheet_full_path}" + raise ValueError(msg) if ( len( stylesheet_full_path.absolute() @@ -219,9 +218,8 @@ def _validate_stylesheet(stylesheet_full_path: Path, assets_app_path: Path) -> N == 1 and stylesheet_full_path.stem == PageNames.STYLESHEET_ROOT ): - raise ValueError( - f"Stylesheet file name cannot be '{PageNames.STYLESHEET_ROOT}': {stylesheet_full_path}" - ) + msg = f"Stylesheet file name cannot be '{PageNames.STYLESHEET_ROOT}': {stylesheet_full_path}" + raise ValueError(msg) RADIX_THEMES_STYLESHEET = "@radix-ui/themes/styles.css" @@ -258,9 +256,8 @@ def _compile_root_stylesheet(stylesheets: list[str]) -> str: stylesheet_full_path = assets_app_path / stylesheet.strip("/") if not stylesheet_full_path.exists(): - raise FileNotFoundError( - f"The stylesheet file {stylesheet_full_path} does not exist." - ) + msg = f"The stylesheet file {stylesheet_full_path} does not exist." + raise FileNotFoundError(msg) if stylesheet_full_path.is_dir(): all_files = ( @@ -687,9 +684,8 @@ def into_component(component: Component | ComponentCallable) -> Component: if (converted := _into_component_once(component)) is not None: return converted if not callable(component): - raise TypeError( - f"Expected a Component or callable, got {component!r} of type {type(component)}" - ) + msg = f"Expected a Component or callable, got {component!r} of type {type(component)}" + raise TypeError(msg) try: component_called = component() @@ -716,9 +712,10 @@ def into_component(component: Component | ComponentCallable) -> Component: "Cannot pass a Var to a built-in function. Consider using .length() for accessing the length of an iterable Var." ).with_traceback(e.__traceback__) from None if message.endswith( - "indices must be integers or slices, not NumberCastedVar" - ) or message.endswith( - "indices must be integers or slices, not BooleanCastedVar" + ( + "indices must be integers or slices, not NumberCastedVar", + "indices must be integers or slices, not BooleanCastedVar", + ) ): raise TypeError( "Cannot index into a primitive sequence with a Var. Consider calling rx.Var.create() on the sequence." @@ -732,9 +729,8 @@ def into_component(component: Component | ComponentCallable) -> Component: if (converted := _into_component_once(component_called)) is not None: return converted - raise TypeError( - f"Expected a Component, got {component_called!r} of type {type(component_called)}" - ) + msg = f"Expected a Component, got {component_called!r} of type {type(component_called)}" + raise TypeError(msg) def compile_unevaluated_page( @@ -889,5 +885,6 @@ def compile_theme(cls, style: ComponentStyle | None) -> tuple[str, str]: ValueError: If the style is not set. """ if style is None: - raise ValueError("STYLE should be set") + msg = "STYLE should be set" + raise ValueError(msg) return compile_theme(style) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 1766090135d..edfd25c7cdc 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -60,7 +60,8 @@ def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]: # Check for default imports. defaults = {field for field in fields_set if field.is_default} if len(defaults) >= 2: - raise ValueError("Only one default import is allowed.") + msg = "Only one default import is allowed." + raise ValueError(msg) # Get the default import, and the specific imports. default = next(iter({field.name for field in defaults}), "") @@ -91,9 +92,8 @@ def validate_imports(import_dict: ParsedImportDict): ): used_tags[import_name] = lib if lib[0] == "$" else already_imported continue - raise ValueError( - f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}" - ) + msg = f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}" + raise ValueError(msg) if import_name is not None: used_tags[import_name] = lib @@ -130,9 +130,11 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]: for path, (default, rest) in compiled.items(): if not lib: if default: - raise ValueError("No default field allowed for empty library.") + msg = "No default field allowed for empty library." + raise ValueError(msg) if rest is None or len(rest) == 0: - raise ValueError("No fields to import.") + msg = "No fields to import." + raise ValueError(msg) import_dicts.extend(get_import_dict(module) for module in sorted(rest)) continue diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index f4c10f0b43b..94244fd004a 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -43,9 +43,8 @@ def validate_str(value: str): f"Output includes {value!s} which will be displayed as a string. If you are calling `str` on a Var, consider using .to_string() instead." ) elif perf_mode == PerformanceMode.RAISE: - raise ValueError( - f"Output includes {value!s} which will be displayed as a string. If you are calling `str` on a Var, consider using .to_string() instead." - ) + msg = f"Output includes {value!s} which will be displayed as a string. If you are calling `str` on a Var, consider using .to_string() instead." + raise ValueError(msg) def _components_from_var(var: Var) -> Sequence[BaseComponent]: @@ -72,10 +71,9 @@ def create(cls, contents: Any) -> Component: if isinstance(contents, LiteralStringVar): validate_str(contents._var_value) return cls._unsafe_create(children=[], contents=contents) - else: - if isinstance(contents, str): - validate_str(contents) - contents = Var.create(contents if contents is not None else "") + if isinstance(contents, str): + validate_str(contents) + contents = Var.create(contents if contents is not None else "") return cls._unsafe_create(children=[], contents=contents) diff --git a/reflex/components/base/meta.py b/reflex/components/base/meta.py index 9f61b81ce53..ca7daee0962 100644 --- a/reflex/components/base/meta.py +++ b/reflex/components/base/meta.py @@ -20,7 +20,8 @@ def render(self) -> dict: """ # Make sure the title is a single string. if len(self.children) != 1 or not isinstance(self.children[0], Bare): - raise ValueError("Title must be a single string.") + msg = "Title must be a single string." + raise ValueError(msg) return super().render() diff --git a/reflex/components/base/script.py b/reflex/components/base/script.py index 15145ecbfeb..9bca2fbb78f 100644 --- a/reflex/components/base/script.py +++ b/reflex/components/base/script.py @@ -68,7 +68,8 @@ def create(cls, *children, **props) -> Component: ValueError: when neither children nor `src` are specified. """ if not children and not props.get("src"): - raise ValueError("Must provide inline script or `src` prop.") + msg = "Must provide inline script or `src` prop." + raise ValueError(msg) return super().create(*children, **props) diff --git a/reflex/components/component.py b/reflex/components/component.py index 6a9c320dc69..6f5721afd1b 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -157,7 +157,8 @@ def default_value(self) -> FIELD_TYPE: return self.default if self.default_factory is not None: return self.default_factory() - raise ValueError("No default value or factory provided.") + msg = "No default value or factory provided." + raise ValueError(msg) def __repr__(self) -> str: """Represent the field in a readable format. @@ -194,7 +195,8 @@ def field( ValueError: If both default and default_factory are specified. """ if default is not MISSING and default_factory is not None: - raise ValueError("cannot specify both default and default_factory") + msg = "cannot specify both default and default_factory" + raise ValueError(msg) return ComponentField( # pyright: ignore [reportReturnType] default=default, default_factory=default_factory, @@ -770,11 +772,12 @@ def _post_init(self, *args, **kwargs): and key not in component_specific_triggers and key not in props ): - raise ValueError( + msg = ( f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}" f" is a third party component make sure to add `{key}` to the component's event triggers. " f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info." ) + raise ValueError(msg) if key in component_specific_triggers: # Event triggers are bound to event chains. is_var = False @@ -845,7 +848,8 @@ def _post_init(self, *args, **kwargs): style = kwargs.get("style", {}) if isinstance(style, Sequence): if any(not isinstance(s, Mapping) for s in style): - raise TypeError("Style must be a dictionary or a list of dictionaries.") + msg = "Style must be a dictionary or a list of dictionaries." + raise TypeError(msg) # Merge styles, the later ones overriding keys in the earlier ones. style = { k: v @@ -880,14 +884,12 @@ def _post_init(self, *args, **kwargs): if not isinstance(c, StringVar) and not issubclass( c._var_type, str ): - raise TypeError( - f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c._js_expr} of type {c._var_type}." - ) + msg = f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c._js_expr} of type {c._var_type}." + raise TypeError(msg) has_var = True else: - raise TypeError( - f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c} of type {type(c)}." - ) + msg = f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c} of type {type(c)}." + raise TypeError(msg) if has_var: kwargs["class_name"] = LiteralArrayVar.create( class_name, _var_type=list[str] @@ -899,9 +901,8 @@ def _post_init(self, *args, **kwargs): and not isinstance(class_name, StringVar) and not issubclass(class_name._var_type, str) ): - raise TypeError( - f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {class_name._js_expr} of type {class_name._var_type}." - ) + msg = f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {class_name._js_expr} of type {class_name._var_type}." + raise TypeError(msg) # Construct the component. for key, value in kwargs.items(): setattr(self, key, value) @@ -1228,9 +1229,8 @@ def _add_style_recursive( """ # 1. Default style from `_add_style`/`add_style`. if type(self)._add_style != Component._add_style: - raise UserWarning( - "Do not override _add_style directly. Use add_style instead." - ) + msg = "Do not override _add_style directly. Use add_style instead." + raise UserWarning(msg) new_style = self._add_style() style_vars = [new_style._var_data] @@ -1348,9 +1348,8 @@ def validate_child(child: Any): validate_child(child.default) if self._invalid_children and child_name in self._invalid_children: - raise ValueError( - f"The component `{comp_name}` cannot have `{child_name}` as a child component" - ) + msg = f"The component `{comp_name}` cannot have `{child_name}` as a child component" + raise ValueError(msg) if self._valid_children and child_name not in [ *self._valid_children, @@ -1359,9 +1358,8 @@ def validate_child(child: Any): valid_child_list = ", ".join( [f"`{v_child}`" for v_child in self._valid_children] ) - raise ValueError( - f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead." - ) + msg = f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead." + raise ValueError(msg) if child._valid_parents and all( clz_name not in [*child._valid_parents, *allowed_components] @@ -1370,9 +1368,8 @@ def validate_child(child: Any): valid_parent_list = ", ".join( [f"`{v_parent}`" for v_parent in child._valid_parents] ) - raise ValueError( - f"The component `{child_name}` can only be a child of the components: {valid_parent_list}. Got `{comp_name}` instead." - ) + msg = f"The component `{child_name}` can only be a child of the components: {valid_parent_list}. Got `{comp_name}` instead." + raise ValueError(msg) for child in children: validate_child(child) @@ -1503,13 +1500,9 @@ def _has_stateful_event_triggers(self): """ if self.event_triggers and self._event_trigger_values_use_state(): return True - else: - for child in self.children: - if ( - isinstance(child, Component) - and child._has_stateful_event_triggers() - ): - return True + for child in self.children: + if isinstance(child, Component) and child._has_stateful_event_triggers(): + return True return False @classmethod @@ -1768,6 +1761,7 @@ def _get_mount_lifecycle_hook(self) -> str | None: {on_unmount or ""} }} }}, []);""" + return None def _get_ref_hook(self) -> Var | None: """Generate the ref hook for the component. @@ -1781,6 +1775,7 @@ def _get_ref_hook(self) -> Var | None: f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};", _var_data=VarData(position=Hooks.HookPosition.INTERNAL), ) + return None def _get_vars_hooks(self) -> dict[str, VarData | None]: """Get the hooks required by vars referenced in this component. @@ -2223,7 +2218,7 @@ def _register_custom_component( _var_type=unwrap_var_annotation(annotation), ).guess_type() if not types.safe_issubclass(annotation, EventHandler) - else EventSpec(handler=EventHandler(fn=lambda: [])) + else EventSpec(handler=EventHandler(fn=no_args_event_spec)) ) for prop, annotation in typing.get_type_hints(component_fn).items() if prop != "return" @@ -2234,7 +2229,8 @@ def _register_custom_component( **dummy_props, ) if dummy_component.tag is None: - raise TypeError(f"Could not determine the tag name for {component_fn!r}") + msg = f"Could not determine the tag name for {component_fn!r}" + raise TypeError(msg) CUSTOM_COMPONENTS[dummy_component.tag] = dummy_component @@ -2313,7 +2309,8 @@ def _get_dynamic_imports(self) -> str: # extract the correct import name from library name base_import_name = self._get_import_name() if base_import_name is None: - raise ValueError("Undefined library for NoSSRComponent") + msg = "Undefined library for NoSSRComponent" + raise ValueError(msg) import_name = format.format_library_name(base_import_name) library_import = f"const {self.alias if self.alias else self.tag} = dynamic(() => import('{import_name}')" @@ -2321,7 +2318,7 @@ def _get_dynamic_imports(self) -> str: # https://nextjs.org/docs/pages/building-your-application/optimizing/lazy-loading#with-named-exports f".then((mod) => mod.{self.tag})" if not self.is_default else "" ) - return "".join((library_import, mod_import, opts_fragment)) + return library_import + mod_import + opts_fragment class StatefulComponent(BaseComponent): @@ -2541,7 +2538,7 @@ def _get_hook_deps(hook: str) -> list[str]: var_name = var_name.strip() # Break up array and object destructuring if used. - if var_name.startswith("[") or var_name.startswith("{"): + if var_name.startswith(("[", "{")): return [ v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",") ] diff --git a/reflex/components/core/breakpoints.py b/reflex/components/core/breakpoints.py index 9a80ad69db2..f203621aa6b 100644 --- a/reflex/components/core/breakpoints.py +++ b/reflex/components/core/breakpoints.py @@ -75,19 +75,17 @@ def create( if custom is not None: if any(threshold is not None for threshold in thresholds): - raise ValueError("Named props cannot be used with custom thresholds") + msg = "Named props cannot be used with custom thresholds" + raise ValueError(msg) return Breakpoints(custom) - else: - return Breakpoints( - { - k: v - for k, v in zip( - ["initial", *breakpoint_names], thresholds, strict=True - ) - if v is not None - } - ) + return Breakpoints( + { + k: v + for k, v in zip(["initial", *breakpoint_names], thresholds, strict=True) + if v is not None + } + ) breakpoints = Breakpoints.create diff --git a/reflex/components/core/colors.py b/reflex/components/core/colors.py index c1ec35e7c34..5556559eedd 100644 --- a/reflex/components/core/colors.py +++ b/reflex/components/core/colors.py @@ -32,19 +32,22 @@ def color( """ if isinstance(color, str): if color not in COLORS and REFLEX_VAR_OPENING_TAG not in color: - raise ValueError(f"Color must be one of {COLORS}, received {color}") + msg = f"Color must be one of {COLORS}, received {color}" + raise ValueError(msg) elif not isinstance(color, Var): - raise ValueError("Color must be a string or a Var") + msg = "Color must be a string or a Var" + raise ValueError(msg) if isinstance(shade, int): if shade < MIN_SHADE_VALUE or shade > MAX_SHADE_VALUE: - raise ValueError( - f"Shade must be between {MIN_SHADE_VALUE} and {MAX_SHADE_VALUE}" - ) + msg = f"Shade must be between {MIN_SHADE_VALUE} and {MAX_SHADE_VALUE}" + raise ValueError(msg) elif not isinstance(shade, Var): - raise ValueError("Shade must be an integer or a Var") + msg = "Shade must be an integer or a Var" + raise ValueError(msg) if not isinstance(alpha, (bool, Var)): - raise ValueError("Alpha must be a boolean or a Var") + msg = "Alpha must be a boolean or a Var" + raise ValueError(msg) return Color(color, shade, alpha) diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 8d14a7c3c29..28ec1224dc3 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -132,7 +132,8 @@ def cond(condition: Any, c1: Any, c2: Any = types.Unset(), /) -> Component | Var # Convert the condition to a Var. cond_var = LiteralVar.create(condition) if cond_var is None: - raise ValueError("The condition must be set.") + msg = "The condition must be set." + raise ValueError(msg) # If the first component is a component, create a Cond component. if isinstance(c1, BaseComponent): @@ -145,7 +146,8 @@ def cond(condition: Any, c1: Any, c2: Any = types.Unset(), /) -> Component | Var if isinstance(c2, BaseComponent): return Cond.create(cond_var.bool(), Fragment.create(c1), c2) if isinstance(c2, types.Unset): - raise ValueError("For conditional vars, the second argument must be set.") + msg = "For conditional vars, the second argument must be set." + raise ValueError(msg) # convert the truth and false cond parts into vars so the _var_data can be obtained. c1_var = Var.create(c1) diff --git a/reflex/components/core/debounce.py b/reflex/components/core/debounce.py index 54e5ccb2909..b0435885066 100644 --- a/reflex/components/core/debounce.py +++ b/reflex/components/core/debounce.py @@ -70,14 +70,16 @@ def create(cls, *children: Component, **props: Any) -> Component: ValueError: if the child element does not have an on_change handler. """ if len(children) != 1: - raise RuntimeError( + msg = ( "Provide a single child for DebounceInput, such as rx.input() or " - "rx.text_area()", + "rx.text_area()" ) + raise RuntimeError(msg) child = children[0] if "on_change" not in child.event_triggers: - raise ValueError("DebounceInput child requires an on_change handler") + msg = "DebounceInput child requires an on_change handler" + raise ValueError(msg) # Carry known props and event_triggers from the child. props_from_child = { diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index 70b190f69b5..d6537a22e4c 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -69,19 +69,19 @@ def create( ) if iterable._var_type == Any: - raise ForeachVarError( + msg = ( f"Could not foreach over var `{iterable!s}` of type Any. " "(If you are trying to foreach over a state var, add a type annotation to the var). " "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" ) + raise ForeachVarError(msg) if ( hasattr(render_fn, "__qualname__") and render_fn.__qualname__ == ComponentState.create.__qualname__ ): - raise TypeError( - "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." - ) + msg = "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." + raise TypeError(msg) if isinstance(iterable, ObjectVar): iterable = iterable.entries() @@ -90,10 +90,11 @@ def create( iterable = iterable.split() if not isinstance(iterable, ArrayVar): - raise ForeachVarError( + msg = ( f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. " "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" ) + raise ForeachVarError(msg) if types.is_optional(iterable._var_type): iterable = cond(iterable, iterable, []) @@ -122,11 +123,12 @@ def _render(self) -> IterTag: # Validate the render function signature. if len(params) == 0 or len(params) > 2: - raise ForeachRenderError( + msg = ( "Expected 1 or 2 parameters in foreach render function, got " f"{[p.name for p in params]}. See " "https://reflex.dev/docs/library/dynamic-rendering/foreach/" ) + raise ForeachRenderError(msg) if len(params) >= 1: # Determine the arg var name based on the params accepted by render_fn. diff --git a/reflex/components/core/html.py b/reflex/components/core/html.py index 60d764c14c0..77f36af5265 100644 --- a/reflex/components/core/html.py +++ b/reflex/components/core/html.py @@ -30,9 +30,9 @@ def create(cls, *children, **props): """ # If children are not provided, throw an error. if len(children) != 1: - raise ValueError("Must provide children to the html component.") - else: - props["dangerouslySetInnerHTML"] = {"__html": children[0]} + msg = "Must provide children to the html component." + raise ValueError(msg) + props["dangerouslySetInnerHTML"] = {"__html": children[0]} # Apply the default classname given_class_name = props.pop("class_name", []) diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index 87b8bd54d8b..f822832ebcc 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -47,9 +47,8 @@ def create(cls, cond: Any, *cases) -> Component | Var: cls._validate_return_types(match_cases) if default is None and isinstance(match_cases[0][-1], Var): - raise ValueError( - "For cases with return types as Vars, a default case must be provided" - ) + msg = "For cases with return types as Vars, a default case must be provided" + raise ValueError(msg) return cls._create_match_cond_var_or_component( match_cond_var, match_cases, default @@ -71,7 +70,8 @@ def _create_condition_var(cls, cond: Any) -> Var: match_cond_var = LiteralVar.create(cond) if match_cond_var is None: - raise ValueError("The condition must be set") + msg = "The condition must be set" + raise ValueError(msg) return match_cond_var @classmethod @@ -90,10 +90,12 @@ def _process_cases(cls, cases: list) -> tuple[list, Var | BaseComponent | None]: default = None if len([case for case in cases if not isinstance(case, tuple)]) > 1: - raise ValueError("rx.match can only have one default case.") + msg = "rx.match can only have one default case." + raise ValueError(msg) if not cases: - raise ValueError("rx.match should have at least one case.") + msg = "rx.match should have at least one case." + raise ValueError(msg) # Get the default case which should be the last non-tuple arg if not isinstance(cases[-1], tuple): @@ -119,8 +121,7 @@ def _create_case_var_with_var_data(cls, case_element: Any) -> Var: The case element Var. """ _var_data = case_element._var_data if isinstance(case_element, Style) else None - case_element = LiteralVar.create(case_element, _var_data=_var_data) - return case_element + return LiteralVar.create(case_element, _var_data=_var_data) @classmethod def _process_match_cases(cls, cases: list) -> list[list[Var]]: @@ -138,14 +139,12 @@ def _process_match_cases(cls, cases: list) -> list[list[Var]]: match_cases = [] for case in cases: if not isinstance(case, tuple): - raise ValueError( - "rx.match should have tuples of cases and a default case as the last argument." - ) + msg = "rx.match should have tuples of cases and a default case as the last argument." + raise ValueError(msg) # There should be at least two elements in a case tuple(a condition and return value) if len(case) < 2: - raise ValueError( - "A case tuple should have at least a match case element and a return value." - ) + msg = "A case tuple should have at least a match case element and a return value." + raise ValueError(msg) case_list = [] for element in case: @@ -156,7 +155,8 @@ def _process_match_cases(cls, cases: list) -> list[list[Var]]: else element ) if not isinstance(el, (Var, BaseComponent)): - raise ValueError("Case element must be a var or component") + msg = "Case element must be a var or component" + raise ValueError(msg) case_list.append(el) match_cases.append(case_list) @@ -183,11 +183,12 @@ def _validate_return_types(cls, match_cases: list[list[Var]]) -> None: for index, case in enumerate(match_cases): if not isinstance(case[-1], return_type): - raise MatchTypeError( + msg = ( f"Match cases should have the same return types. Case {index} with return " f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`" f" of type {type(case[-1])!r} is not {return_type}" ) + raise MatchTypeError(msg) @classmethod def _create_match_cond_var_or_component( @@ -226,7 +227,8 @@ def _create_match_cond_var_or_component( if any( case for case in match_cases if not isinstance(case[-1], Var) ) or not isinstance(default, Var): - raise ValueError("Return types of match cases should be Vars.") + msg = "Return types of match cases should be Vars." + raise ValueError(msg) return Var( _js_expr=format.format_match( diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index 175789bb166..2fc72200b6b 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -482,8 +482,7 @@ def create( if copy_button: return Box.create(code_block, copy_button, position="relative") - else: - return code_block + return code_block def add_style(self): """Add style to the component.""" diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index 0b968634358..6b7ba25186e 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -383,9 +383,8 @@ def create(cls, *children, **props) -> Component: # If rows is not provided, determine from data. if rows is None: if isinstance(data, Var) and not isinstance(data, ArrayVar): - raise ValueError( - "DataEditor data must be an ArrayVar if rows is not provided." - ) + msg = "DataEditor data must be an ArrayVar if rows is not provided." + raise ValueError(msg) props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data) @@ -393,13 +392,11 @@ def create(cls, *children, **props) -> Component: if types.is_dataframe(type(data)) or ( isinstance(data, Var) and types.is_dataframe(data._var_type) ): - raise ValueError( - "Cannot pass in both a pandas dataframe and columns to the data_editor component." - ) - else: - props["columns"] = [ - format.format_data_editor_column(col) for col in columns - ] + msg = "Cannot pass in both a pandas dataframe and columns to the data_editor component." + raise ValueError(msg) + props["columns"] = [ + format.format_data_editor_column(col) for col in columns + ] if "theme" in props: theme = props.get("theme") diff --git a/reflex/components/datadisplay/logo.py b/reflex/components/datadisplay/logo.py index 444686a1610..235ed91c2e2 100644 --- a/reflex/components/datadisplay/logo.py +++ b/reflex/components/datadisplay/logo.py @@ -2,11 +2,10 @@ import reflex as rx +SVG_COLOR = rx.color_mode_cond("#110F1F", "white") -def svg_logo( - color: str | rx.Var[str] = rx.color_mode_cond("#110F1F", "white"), - **props, -): + +def svg_logo(color: str | rx.Var[str] = SVG_COLOR, **props): """A Reflex logo SVG. Args: diff --git a/reflex/components/datadisplay/shiki_code_block.py b/reflex/components/datadisplay/shiki_code_block.py index bcd348cce80..1f14901e0f1 100644 --- a/reflex/components/datadisplay/shiki_code_block.py +++ b/reflex/components/datadisplay/shiki_code_block.py @@ -636,9 +636,8 @@ def add_imports(self) -> dict[str, list[str]]: """ imports = defaultdict(list) if not isinstance(self.transformers, LiteralVar): - raise ValueError( - f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead." - ) + msg = f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead." + raise ValueError(msg) for transformer in self.transformers._var_value: if isinstance(transformer, ShikiBaseTransformers): imports[transformer.library].extend( @@ -663,9 +662,8 @@ def create_transformer(cls, library: str, fns: list[str]) -> ShikiBaseTransforme ValueError: If a supplied function name is not valid str. """ if any(not isinstance(fn_name, str) for fn_name in fns): - raise ValueError( - f"the function names should be str names of functions in the specified transformer: {library!r}" - ) + msg = f"the function names should be str names of functions in the specified transformer: {library!r}" + raise ValueError(msg) return ShikiBaseTransformers( library=library, fns=[FunctionStringVar.create(fn) for fn in fns], # pyright: ignore [reportCallIssue] @@ -811,8 +809,7 @@ def create( return ShikiCodeBlock.create( children[0], copy_button, position="relative", **props ) - else: - return ShikiCodeBlock.create(children[0], **props) + return ShikiCodeBlock.create(children[0], **props) @staticmethod def _map_themes(theme: str) -> str: @@ -829,9 +826,8 @@ def _map_languages(language: str) -> str: @staticmethod def _strip_transformer_triggers(code: str | StringVar) -> StringVar | str: if not isinstance(code, (StringVar, str)): - raise VarTypeError( - f"code should be string literal or a StringVar type. Got {type(code)} instead." - ) + msg = f"code should be string literal or a StringVar type. Got {type(code)} instead." + raise VarTypeError(msg) regex_pattern = r"[\/#]+ *\[!code.*?\]" if isinstance(code, Var): @@ -840,6 +836,7 @@ def _strip_transformer_triggers(code: str | StringVar) -> StringVar | str: ) if isinstance(code, str): return re.sub(regex_pattern, "", code) + return None class TransformerNamespace(ComponentNamespace): diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py index b2498b8418d..5e2d9a6680b 100644 --- a/reflex/components/dynamic.py +++ b/reflex/components/dynamic.py @@ -50,9 +50,8 @@ def bundle_library(component: Union["Component", str]): bundled_libraries.add(component) return if component.library is None: - raise DynamicComponentMissingLibraryError( - "Component must have a library to bundle." - ) + msg = "Component must have a library to bundle." + raise DynamicComponentMissingLibraryError(msg) bundled_libraries.add(format_library_name(component.library)) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 36bae27af6f..6b76d285c41 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -746,9 +746,8 @@ def create(cls, *children, **props): if enter_key_submit is not None: enter_key_submit = Var.create(enter_key_submit) if "on_key_down" in props: - raise ValueError( - "Cannot combine `enter_key_submit` with `on_key_down`.", - ) + msg = "Cannot combine `enter_key_submit` with `on_key_down`." + raise ValueError(msg) custom_attrs["on_key_down"] = Var( _js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {enter_key_submit!s})", _var_data=VarData.merge(enter_key_submit._get_all_var_data()), diff --git a/reflex/components/el/elements/metadata.py b/reflex/components/el/elements/metadata.py index bc961ce14da..e49d73d5515 100644 --- a/reflex/components/el/elements/metadata.py +++ b/reflex/components/el/elements/metadata.py @@ -13,7 +13,6 @@ class Base(BaseHTML): tag = "base" - tag = "base" href: Var[str] target: Var[str] diff --git a/reflex/components/gridjs/datatable.py b/reflex/components/gridjs/datatable.py index 38a05dffc75..8472c45b9d7 100644 --- a/reflex/components/gridjs/datatable.py +++ b/reflex/components/gridjs/datatable.py @@ -67,36 +67,32 @@ def create(cls, *children, **props): # The annotation should be provided if data is a computed var. We need this to know how to # render pandas dataframes. if is_computed_var(data) and data._var_type == Any: - raise ValueError( - "Annotation of the computed var assigned to the data field should be provided." - ) + msg = "Annotation of the computed var assigned to the data field should be provided." + raise ValueError(msg) if ( columns is not None and is_computed_var(columns) and columns._var_type == Any ): - raise ValueError( - "Annotation of the computed var assigned to the column field should be provided." - ) + msg = "Annotation of the computed var assigned to the column field should be provided." + raise ValueError(msg) # If data is a pandas dataframe and columns are provided throw an error. if ( types.is_dataframe(type(data)) or (isinstance(data, Var) and types.is_dataframe(data._var_type)) ) and columns is not None: - raise ValueError( - "Cannot pass in both a pandas dataframe and columns to the data_table component." - ) + msg = "Cannot pass in both a pandas dataframe and columns to the data_table component." + raise ValueError(msg) # If data is a list and columns are not provided, throw an error if ( (isinstance(data, Var) and types.typehint_issubclass(data._var_type, list)) or isinstance(data, list) ) and columns is None: - raise ValueError( - "column field should be specified when the data field is a list type" - ) + msg = "column field should be specified when the data field is a list type" + raise ValueError(msg) # Create the component. return super().create( @@ -126,7 +122,8 @@ def _render(self) -> Tag: # If given a pandas df break up the data and columns data = serialize(self.data) if not isinstance(data, dict): - raise ValueError("Serialized dataframe should be a dict.") + msg = "Serialized dataframe should be a dict." + raise ValueError(msg) self.columns = LiteralVar.create(data["columns"]) self.data = LiteralVar.create(data["data"]) diff --git a/reflex/components/lucide/icon.py b/reflex/components/lucide/icon.py index 9550e4e44c5..56ab9a4f924 100644 --- a/reflex/components/lucide/icon.py +++ b/reflex/components/lucide/icon.py @@ -42,27 +42,28 @@ def create(cls, *children, **props) -> Component: if len(children) == 1: child = Var.create(children[0]).guess_type() if not isinstance(child, StringVar): - raise AttributeError( - f"Icon name must be a string, got {children[0]._var_type if isinstance(children[0], Var) else children[0]}" - ) + msg = f"Icon name must be a string, got {children[0]._var_type if isinstance(children[0], Var) else children[0]}" + raise AttributeError(msg) props["tag"] = children[0] else: - raise AttributeError( - f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix" - ) + msg = f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix" + raise AttributeError(msg) if "tag" not in props: - raise AttributeError("Missing 'tag' keyword-argument for Icon") + msg = "Missing 'tag' keyword-argument for Icon" + raise AttributeError(msg) tag_var: Var | LiteralVar = Var.create(props.pop("tag")) if isinstance(tag_var, LiteralVar): if isinstance(tag_var, LiteralStringVar): tag = format.to_snake_case(tag_var._var_value.lower()) else: - raise TypeError(f"Icon name must be a string, got {type(tag_var)}") + msg = f"Icon name must be a string, got {type(tag_var)}" + raise TypeError(msg) elif isinstance(tag_var, Var): tag_stringified = tag_var.guess_type() if not isinstance(tag_stringified, StringVar): - raise TypeError(f"Icon name must be a string, got {tag_var._var_type}") + msg = f"Icon name must be a string, got {tag_var._var_type}" + raise TypeError(msg) return DynamicIcon.create(name=tag_stringified.replace("_", "-"), **props) if tag not in LUCIDE_ICON_LIST: diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 36c559516ae..60e469591b6 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -170,9 +170,8 @@ def create(cls, *children, **props) -> Component: The markdown component. """ if len(children) != 1 or not isinstance(children[0], (str, Var)): - raise ValueError( - "Markdown component must have exactly one child containing the markdown source." - ) + msg = "Markdown component must have exactly one child containing the markdown source." + raise ValueError(msg) # Update the base component map with the custom component map. component_map = {**get_base_component_map(), **props.pop("component_map", {})} @@ -319,7 +318,8 @@ def get_component(self, tag: str, **props) -> Component: """ # Check the tag is valid. if tag not in self.component_map: - raise ValueError(f"No markdown component found for tag: {tag}.") + msg = f"No markdown component found for tag: {tag}." + raise ValueError(msg) special_props = [_PROPS] children = [ @@ -342,10 +342,9 @@ def get_component(self, tag: str, **props) -> Component: if children_prop is not None: children = [] # Get the component. - component = self.component_map[tag](*children, **props).set( + return self.component_map[tag](*children, **props).set( special_props=special_props ) - return component def format_component(self, tag: str, **props) -> str: """Format a component for rendering in the component map. @@ -437,7 +436,7 @@ def _get_custom_code(self) -> str | None: """ def _render(self) -> Tag: - tag = ( + return ( super() ._render() .add_props( @@ -447,4 +446,3 @@ def _render(self) -> Tag: ) .remove_props("componentMap", "componentMapHash") ) - return tag diff --git a/reflex/components/next/base.py b/reflex/components/next/base.py index 1dd3db437c2..f16701d88a6 100644 --- a/reflex/components/next/base.py +++ b/reflex/components/next/base.py @@ -5,5 +5,3 @@ class NextComponent(Component): """A Component used as based for any NextJS component.""" - - ... diff --git a/reflex/components/props.py b/reflex/components/props.py index 55051048fe4..8cd88497188 100644 --- a/reflex/components/props.py +++ b/reflex/components/props.py @@ -66,9 +66,8 @@ def __init__(self, component_name: str | None = None, **kwargs): except ValidationError as e: invalid_fields = ", ".join([error["loc"][0] for error in e.errors()]) # pyright: ignore [reportCallIssue, reportArgumentType] supported_props_str = ", ".join(f'"{field}"' for field in self.get_fields()) - raise InvalidPropValueError( - f"Invalid prop(s) {invalid_fields} for {component_name!r}. Supported props are {supported_props_str}" - ) from None + msg = f"Invalid prop(s) {invalid_fields} for {component_name!r}. Supported props are {supported_props_str}" + raise InvalidPropValueError(msg) from None class Config: # pyright: ignore [reportIncompatibleVariableOverride] """Pydantic config.""" diff --git a/reflex/components/radix/primitives/base.py b/reflex/components/radix/primitives/base.py index dae9ce3839c..c5dfb938cc5 100644 --- a/reflex/components/radix/primitives/base.py +++ b/reflex/components/radix/primitives/base.py @@ -21,8 +21,6 @@ def _render(self) -> Tag: super() ._render() .add_props( - **{ - "class_name": f"{format.to_title_case(self.tag or '')} {self.class_name or ''}", - } + class_name=f"{format.to_title_case(self.tag or '')} {self.class_name or ''}" ) ) diff --git a/reflex/components/radix/primitives/form.py b/reflex/components/radix/primitives/form.py index a77d576df27..ef8669b6e63 100644 --- a/reflex/components/radix/primitives/form.py +++ b/reflex/components/radix/primitives/form.py @@ -100,14 +100,12 @@ def create(cls, *children, **props): The form control component. """ if len(children) > 1: - raise ValueError( - f"FormControl can only have at most one child, got {len(children)} children" - ) + msg = f"FormControl can only have at most one child, got {len(children)} children" + raise ValueError(msg) for child in children: if not isinstance(child, (TextFieldRoot, DebounceInput)): - raise TypeError( - "Only Radix TextFieldRoot and DebounceInput are allowed as children of FormControl" - ) + msg = "Only Radix TextFieldRoot and DebounceInput are allowed as children of FormControl" + raise TypeError(msg) return super().create(*children, **props) @@ -168,8 +166,6 @@ class FormSubmit(FormComponent): class Form(FormRoot): """The Form component.""" - pass - class FormNamespace(ComponentNamespace): """Form components.""" diff --git a/reflex/components/radix/themes/components/alert_dialog.py b/reflex/components/radix/themes/components/alert_dialog.py index bc5e2dc7ed4..947f4d3a12f 100644 --- a/reflex/components/radix/themes/components/alert_dialog.py +++ b/reflex/components/radix/themes/components/alert_dialog.py @@ -5,12 +5,14 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + RadixThemesComponent, + RadixThemesTriggerComponent, +) from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var -from ..base import RadixThemesComponent, RadixThemesTriggerComponent - LiteralContentSize = Literal["1", "2", "3", "4"] diff --git a/reflex/components/radix/themes/components/aspect_ratio.py b/reflex/components/radix/themes/components/aspect_ratio.py index d821bd52378..2fecd8d407e 100644 --- a/reflex/components/radix/themes/components/aspect_ratio.py +++ b/reflex/components/radix/themes/components/aspect_ratio.py @@ -1,9 +1,8 @@ """Interactive components provided by @radix-ui/themes.""" +from reflex.components.radix.themes.base import RadixThemesComponent from reflex.vars.base import Var -from ..base import RadixThemesComponent - class AspectRatio(RadixThemesComponent): """Displays content with a desired ratio.""" diff --git a/reflex/components/radix/themes/components/avatar.py b/reflex/components/radix/themes/components/avatar.py index 77a305e29ee..65ec20ca905 100644 --- a/reflex/components/radix/themes/components/avatar.py +++ b/reflex/components/radix/themes/components/avatar.py @@ -3,10 +3,13 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralRadius, + RadixThemesComponent, +) from reflex.vars.base import Var -from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent - LiteralSize = Literal["1", "2", "3", "4", "5", "6", "7", "8", "9"] diff --git a/reflex/components/radix/themes/components/badge.py b/reflex/components/radix/themes/components/badge.py index 389012bf08b..d1f4c943b6e 100644 --- a/reflex/components/radix/themes/components/badge.py +++ b/reflex/components/radix/themes/components/badge.py @@ -4,10 +4,13 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralRadius, + RadixThemesComponent, +) from reflex.vars.base import Var -from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent - class Badge(elements.Span, RadixThemesComponent): """A stylized badge element.""" diff --git a/reflex/components/radix/themes/components/button.py b/reflex/components/radix/themes/components/button.py index cb44ee68400..25ab0041558 100644 --- a/reflex/components/radix/themes/components/button.py +++ b/reflex/components/radix/themes/components/button.py @@ -4,15 +4,14 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements -from reflex.vars.base import Var - -from ..base import ( +from reflex.components.radix.themes.base import ( LiteralAccentColor, LiteralRadius, LiteralVariant, RadixLoadingProp, RadixThemesComponent, ) +from reflex.vars.base import Var LiteralButtonSize = Literal["1", "2", "3", "4"] diff --git a/reflex/components/radix/themes/components/callout.py b/reflex/components/radix/themes/components/callout.py index a75b421b6b4..cae6a11eff8 100644 --- a/reflex/components/radix/themes/components/callout.py +++ b/reflex/components/radix/themes/components/callout.py @@ -7,10 +7,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements from reflex.components.lucide.icon import Icon +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - CalloutVariant = Literal["soft", "surface", "outline"] diff --git a/reflex/components/radix/themes/components/card.py b/reflex/components/radix/themes/components/card.py index e99ea9cef16..708bc0fd8c2 100644 --- a/reflex/components/radix/themes/components/card.py +++ b/reflex/components/radix/themes/components/card.py @@ -4,10 +4,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import RadixThemesComponent from reflex.vars.base import Var -from ..base import RadixThemesComponent - class Card(elements.Div, RadixThemesComponent): """Container that groups related content and actions.""" diff --git a/reflex/components/radix/themes/components/checkbox.py b/reflex/components/radix/themes/components/checkbox.py index 42277bfea60..4ebf49ed93a 100644 --- a/reflex/components/radix/themes/components/checkbox.py +++ b/reflex/components/radix/themes/components/checkbox.py @@ -4,12 +4,15 @@ from reflex.components.component import Component, ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralSpacing, + RadixThemesComponent, +) from reflex.components.radix.themes.layout.flex import Flex from reflex.components.radix.themes.typography.text import Text from reflex.event import EventHandler, passthrough_event_spec -from reflex.vars.base import LiteralVar, Var - -from ..base import LiteralAccentColor, LiteralSpacing, RadixThemesComponent +from reflex.vars.base import Var LiteralCheckboxSize = Literal["1", "2", "3"] LiteralCheckboxVariant = Literal["classic", "surface", "soft"] @@ -111,7 +114,7 @@ class HighLevelCheckbox(RadixThemesComponent): on_change: EventHandler[passthrough_event_spec(bool)] @classmethod - def create(cls, text: Var[str] = LiteralVar.create(""), **props) -> Component: + def create(cls, text: Var[str] = Var.create(""), **props) -> Component: """Create a checkbox with a label. Args: diff --git a/reflex/components/radix/themes/components/checkbox_cards.py b/reflex/components/radix/themes/components/checkbox_cards.py index 6fd8a7f30c3..e86df97fe15 100644 --- a/reflex/components/radix/themes/components/checkbox_cards.py +++ b/reflex/components/radix/themes/components/checkbox_cards.py @@ -4,10 +4,9 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - class CheckboxCardsRoot(RadixThemesComponent): """Root element for a CheckboxCards component.""" diff --git a/reflex/components/radix/themes/components/checkbox_group.py b/reflex/components/radix/themes/components/checkbox_group.py index e5872e2f72f..cd3a1bc7a04 100644 --- a/reflex/components/radix/themes/components/checkbox_group.py +++ b/reflex/components/radix/themes/components/checkbox_group.py @@ -5,10 +5,9 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - class CheckboxGroupRoot(RadixThemesComponent): """Root element for a CheckboxGroup component.""" diff --git a/reflex/components/radix/themes/components/context_menu.py b/reflex/components/radix/themes/components/context_menu.py index 742957d1c83..f99edea3efa 100644 --- a/reflex/components/radix/themes/components/context_menu.py +++ b/reflex/components/radix/themes/components/context_menu.py @@ -4,11 +4,11 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent from .checkbox import Checkbox from .radio_group import HighLevelRadioGroup diff --git a/reflex/components/radix/themes/components/data_list.py b/reflex/components/radix/themes/components/data_list.py index 05d4af074ab..26f0232797c 100644 --- a/reflex/components/radix/themes/components/data_list.py +++ b/reflex/components/radix/themes/components/data_list.py @@ -4,10 +4,9 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - class DataListRoot(RadixThemesComponent): """Root element for a DataList component.""" diff --git a/reflex/components/radix/themes/components/dialog.py b/reflex/components/radix/themes/components/dialog.py index ce6e52cb5eb..4d4a62bc4ee 100644 --- a/reflex/components/radix/themes/components/dialog.py +++ b/reflex/components/radix/themes/components/dialog.py @@ -5,12 +5,14 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + RadixThemesComponent, + RadixThemesTriggerComponent, +) from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var -from ..base import RadixThemesComponent, RadixThemesTriggerComponent - class DialogRoot(RadixThemesComponent): """Root component for Dialog.""" diff --git a/reflex/components/radix/themes/components/dropdown_menu.py b/reflex/components/radix/themes/components/dropdown_menu.py index b23411ffeb3..946ab46485b 100644 --- a/reflex/components/radix/themes/components/dropdown_menu.py +++ b/reflex/components/radix/themes/components/dropdown_menu.py @@ -4,12 +4,15 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + RadixThemesComponent, + RadixThemesTriggerComponent, +) from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent, RadixThemesTriggerComponent - LiteralDirType = Literal["ltr", "rtl"] LiteralSizeType = Literal["1", "2"] diff --git a/reflex/components/radix/themes/components/hover_card.py b/reflex/components/radix/themes/components/hover_card.py index 03559083edd..110113427cb 100644 --- a/reflex/components/radix/themes/components/hover_card.py +++ b/reflex/components/radix/themes/components/hover_card.py @@ -5,12 +5,14 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + RadixThemesComponent, + RadixThemesTriggerComponent, +) from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, passthrough_event_spec from reflex.vars.base import Var -from ..base import RadixThemesComponent, RadixThemesTriggerComponent - class HoverCardRoot(RadixThemesComponent): """For sighted users to preview content available behind a link.""" diff --git a/reflex/components/radix/themes/components/icon_button.py b/reflex/components/radix/themes/components/icon_button.py index 7d865365be7..0733da92ae1 100644 --- a/reflex/components/radix/themes/components/icon_button.py +++ b/reflex/components/radix/themes/components/icon_button.py @@ -9,16 +9,15 @@ from reflex.components.core.match import Match from reflex.components.el import elements from reflex.components.lucide import Icon -from reflex.style import Style -from reflex.vars.base import Var - -from ..base import ( +from reflex.components.radix.themes.base import ( LiteralAccentColor, LiteralRadius, LiteralVariant, RadixLoadingProp, RadixThemesComponent, ) +from reflex.style import Style +from reflex.vars.base import Var LiteralButtonSize = Literal["1", "2", "3", "4"] @@ -70,9 +69,8 @@ def create(cls, *children, **props) -> Component: ) ] else: - raise ValueError( - "IconButton requires a child icon. Pass a string as the first child or a rx.icon." - ) + msg = "IconButton requires a child icon. Pass a string as the first child or a rx.icon." + raise ValueError(msg) if "size" in props: if isinstance(props["size"], str): children[0].size = RADIX_TO_LUCIDE_SIZE[props["size"]] # pyright: ignore[reportAttributeAccessIssue] @@ -83,7 +81,8 @@ def create(cls, *children, **props) -> Component: 12, ) if not isinstance(size_map_var, Var): - raise ValueError(f"Match did not return a Var: {size_map_var}") + msg = f"Match did not return a Var: {size_map_var}" + raise ValueError(msg) children[0].size = size_map_var # pyright: ignore[reportAttributeAccessIssue] return super().create(*children, **props) diff --git a/reflex/components/radix/themes/components/inset.py b/reflex/components/radix/themes/components/inset.py index 8e7482de99c..df82ce65a02 100644 --- a/reflex/components/radix/themes/components/inset.py +++ b/reflex/components/radix/themes/components/inset.py @@ -4,10 +4,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import RadixThemesComponent from reflex.vars.base import Var -from ..base import RadixThemesComponent - LiteralButtonSize = Literal["1", "2", "3", "4"] diff --git a/reflex/components/radix/themes/components/popover.py b/reflex/components/radix/themes/components/popover.py index f783acf9e3d..650a8a5af8d 100644 --- a/reflex/components/radix/themes/components/popover.py +++ b/reflex/components/radix/themes/components/popover.py @@ -5,12 +5,14 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + RadixThemesComponent, + RadixThemesTriggerComponent, +) from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var -from ..base import RadixThemesComponent, RadixThemesTriggerComponent - class PopoverRoot(RadixThemesComponent): """Floating element for displaying rich content, triggered by a button.""" diff --git a/reflex/components/radix/themes/components/progress.py b/reflex/components/radix/themes/components/progress.py index e9fe168c6d5..a6e31ca8cec 100644 --- a/reflex/components/radix/themes/components/progress.py +++ b/reflex/components/radix/themes/components/progress.py @@ -4,11 +4,10 @@ from reflex.components.component import Component from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.style import Style from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - class Progress(RadixThemesComponent): """A progress bar component.""" diff --git a/reflex/components/radix/themes/components/radio.py b/reflex/components/radix/themes/components/radio.py index fd24bb6b50c..c649c01c114 100644 --- a/reflex/components/radix/themes/components/radio.py +++ b/reflex/components/radix/themes/components/radio.py @@ -3,10 +3,9 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - class Radio(RadixThemesComponent): """A radio component.""" diff --git a/reflex/components/radix/themes/components/radio_cards.py b/reflex/components/radix/themes/components/radio_cards.py index dabf0ab4d4e..823d68917c7 100644 --- a/reflex/components/radix/themes/components/radio_cards.py +++ b/reflex/components/radix/themes/components/radio_cards.py @@ -4,11 +4,10 @@ from typing import ClassVar, Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.event import EventHandler, passthrough_event_spec from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - class RadioCardsRoot(RadixThemesComponent): """Root element for RadioCards component.""" diff --git a/reflex/components/radix/themes/components/radio_group.py b/reflex/components/radix/themes/components/radio_group.py index 9e8c2df9002..528bd025fc3 100644 --- a/reflex/components/radix/themes/components/radio_group.py +++ b/reflex/components/radix/themes/components/radio_group.py @@ -8,6 +8,11 @@ import reflex as rx from reflex.components.component import Component, ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralSpacing, + RadixThemesComponent, +) from reflex.components.radix.themes.layout.flex import Flex from reflex.components.radix.themes.typography.text import Text from reflex.event import EventHandler, passthrough_event_spec @@ -15,8 +20,6 @@ from reflex.vars.base import LiteralVar, Var from reflex.vars.sequence import StringVar -from ..base import LiteralAccentColor, LiteralSpacing, RadixThemesComponent - LiteralFlexDirection = Literal["row", "column", "row-reverse", "column-reverse"] @@ -145,9 +148,8 @@ def create( isinstance(items, Var) and not types._issubclass(items._var_type, list) ): items_type = type(items) if not isinstance(items, Var) else items._var_type - raise TypeError( - f"The radio group component takes in a list, got {items_type} instead" - ) + msg = f"The radio group component takes in a list, got {items_type} instead" + raise TypeError(msg) default_value = LiteralVar.create(default_value) diff --git a/reflex/components/radix/themes/components/scroll_area.py b/reflex/components/radix/themes/components/scroll_area.py index 516649e12ce..5eaff849104 100644 --- a/reflex/components/radix/themes/components/scroll_area.py +++ b/reflex/components/radix/themes/components/scroll_area.py @@ -2,10 +2,9 @@ from typing import Literal +from reflex.components.radix.themes.base import RadixThemesComponent from reflex.vars.base import Var -from ..base import RadixThemesComponent - class ScrollArea(RadixThemesComponent): """Custom styled, cross-browser scrollable area using native functionality.""" diff --git a/reflex/components/radix/themes/components/segmented_control.py b/reflex/components/radix/themes/components/segmented_control.py index 057c2f55519..cfcdaa3fd92 100644 --- a/reflex/components/radix/themes/components/segmented_control.py +++ b/reflex/components/radix/themes/components/segmented_control.py @@ -7,11 +7,10 @@ from typing import ClassVar, Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.event import EventHandler from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - def on_value_change( value: Var[str | list[str]], diff --git a/reflex/components/radix/themes/components/select.py b/reflex/components/radix/themes/components/select.py index 4d473e0dc51..fc5d6006499 100644 --- a/reflex/components/radix/themes/components/select.py +++ b/reflex/components/radix/themes/components/select.py @@ -6,12 +6,15 @@ import reflex as rx from reflex.components.component import Component, ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralRadius, + RadixThemesComponent, +) from reflex.constants.compiler import MemoizationMode from reflex.event import no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var -from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent - class SelectRoot(RadixThemesComponent): """Displays a list of options for the user to pick from, triggered by a button.""" diff --git a/reflex/components/radix/themes/components/separator.py b/reflex/components/radix/themes/components/separator.py index 9fc06807a1a..3d120a8525e 100644 --- a/reflex/components/radix/themes/components/separator.py +++ b/reflex/components/radix/themes/components/separator.py @@ -3,10 +3,9 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import LiteralVar, Var -from ..base import LiteralAccentColor, RadixThemesComponent - LiteralSeperatorSize = Literal["1", "2", "3", "4"] diff --git a/reflex/components/radix/themes/components/skeleton.py b/reflex/components/radix/themes/components/skeleton.py index 57eba6234ad..e9283407311 100644 --- a/reflex/components/radix/themes/components/skeleton.py +++ b/reflex/components/radix/themes/components/skeleton.py @@ -1,11 +1,10 @@ """Skeleton theme from Radix components.""" from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import RadixLoadingProp, RadixThemesComponent from reflex.constants.compiler import MemoizationMode from reflex.vars.base import Var -from ..base import RadixLoadingProp, RadixThemesComponent - class Skeleton(RadixLoadingProp, RadixThemesComponent): """Skeleton component.""" diff --git a/reflex/components/radix/themes/components/slider.py b/reflex/components/radix/themes/components/slider.py index b1a3383b42f..17e434238c0 100644 --- a/reflex/components/radix/themes/components/slider.py +++ b/reflex/components/radix/themes/components/slider.py @@ -7,12 +7,11 @@ from reflex.components.component import Component from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.event import EventHandler, passthrough_event_spec from reflex.utils.types import typehint_issubclass from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - on_value_event_spec = ( passthrough_event_spec(list[int | float]), passthrough_event_spec(list[int]), diff --git a/reflex/components/radix/themes/components/spinner.py b/reflex/components/radix/themes/components/spinner.py index 620d248c4e0..e777b45a274 100644 --- a/reflex/components/radix/themes/components/spinner.py +++ b/reflex/components/radix/themes/components/spinner.py @@ -3,10 +3,9 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import RadixLoadingProp, RadixThemesComponent from reflex.vars.base import Var -from ..base import RadixLoadingProp, RadixThemesComponent - LiteralSpinnerSize = Literal["1", "2", "3"] diff --git a/reflex/components/radix/themes/components/switch.py b/reflex/components/radix/themes/components/switch.py index 2af4f55bb41..db4fb64472a 100644 --- a/reflex/components/radix/themes/components/switch.py +++ b/reflex/components/radix/themes/components/switch.py @@ -3,11 +3,10 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.event import EventHandler, passthrough_event_spec from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - LiteralSwitchSize = Literal["1", "2", "3"] diff --git a/reflex/components/radix/themes/components/table.py b/reflex/components/radix/themes/components/table.py index 7f3ba7a3073..e9aaa2733c9 100644 --- a/reflex/components/radix/themes/components/table.py +++ b/reflex/components/radix/themes/components/table.py @@ -5,10 +5,9 @@ from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import CommonPaddingProps, RadixThemesComponent from reflex.vars.base import Var -from ..base import CommonPaddingProps, RadixThemesComponent - class TableRoot(elements.Table, RadixThemesComponent): """A semantic table for presenting tabular data.""" diff --git a/reflex/components/radix/themes/components/tabs.py b/reflex/components/radix/themes/components/tabs.py index 1a4bbbd3079..d430a0bdce7 100644 --- a/reflex/components/radix/themes/components/tabs.py +++ b/reflex/components/radix/themes/components/tabs.py @@ -7,12 +7,11 @@ from reflex.components.component import Component, ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.core.colors import color +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, passthrough_event_spec from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent - vertical_orientation_css = "&[data-orientation='vertical']" diff --git a/reflex/components/radix/themes/components/text_area.py b/reflex/components/radix/themes/components/text_area.py index 0cab7459d97..4d0f61123c8 100644 --- a/reflex/components/radix/themes/components/text_area.py +++ b/reflex/components/radix/themes/components/text_area.py @@ -6,10 +6,13 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.core.debounce import DebounceInput from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralRadius, + RadixThemesComponent, +) from reflex.vars.base import Var -from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent - LiteralTextAreaSize = Literal["1", "2", "3"] LiteralTextAreaResize = Literal["none", "vertical", "horizontal", "both"] diff --git a/reflex/components/radix/themes/components/text_field.py b/reflex/components/radix/themes/components/text_field.py index 7de977eb358..198597c2b87 100644 --- a/reflex/components/radix/themes/components/text_field.py +++ b/reflex/components/radix/themes/components/text_field.py @@ -8,13 +8,16 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.core.debounce import DebounceInput from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralRadius, + RadixThemesComponent, +) from reflex.event import EventHandler, input_event, key_event from reflex.utils.types import is_optional from reflex.vars.base import Var from reflex.vars.number import ternary_operation -from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent - LiteralTextFieldSize = Literal["1", "2", "3"] LiteralTextFieldVariant = Literal["classic", "surface", "soft"] diff --git a/reflex/components/radix/themes/components/tooltip.py b/reflex/components/radix/themes/components/tooltip.py index 46dff65b495..ef7a75723c1 100644 --- a/reflex/components/radix/themes/components/tooltip.py +++ b/reflex/components/radix/themes/components/tooltip.py @@ -3,13 +3,12 @@ from typing import Literal from reflex.components.component import Component +from reflex.components.radix.themes.base import RadixThemesComponent from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.utils import format from reflex.vars.base import Var -from ..base import RadixThemesComponent - LiteralSideType = Literal[ "top", "right", diff --git a/reflex/components/radix/themes/layout/base.py b/reflex/components/radix/themes/layout/base.py index f31f6a72c8b..c67a4e51267 100644 --- a/reflex/components/radix/themes/layout/base.py +++ b/reflex/components/radix/themes/layout/base.py @@ -5,10 +5,13 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import ( + CommonMarginProps, + CommonPaddingProps, + RadixThemesComponent, +) from reflex.vars.base import Var -from ..base import CommonMarginProps, CommonPaddingProps, RadixThemesComponent - LiteralBoolNumber = Literal["0", "1"] diff --git a/reflex/components/radix/themes/layout/box.py b/reflex/components/radix/themes/layout/box.py index a8ace5956d3..3932d2b9e1a 100644 --- a/reflex/components/radix/themes/layout/box.py +++ b/reflex/components/radix/themes/layout/box.py @@ -3,8 +3,7 @@ from __future__ import annotations from reflex.components.el import elements - -from ..base import RadixThemesComponent +from reflex.components.radix.themes.base import RadixThemesComponent class Box(elements.Div, RadixThemesComponent): diff --git a/reflex/components/radix/themes/layout/container.py b/reflex/components/radix/themes/layout/container.py index b1d2fbed385..5c484188b61 100644 --- a/reflex/components/radix/themes/layout/container.py +++ b/reflex/components/radix/themes/layout/container.py @@ -6,11 +6,10 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import RadixThemesComponent from reflex.style import STACK_CHILDREN_FULL_WIDTH from reflex.vars.base import LiteralVar, Var -from ..base import RadixThemesComponent - LiteralContainerSize = Literal["1", "2", "3", "4"] diff --git a/reflex/components/radix/themes/layout/flex.py b/reflex/components/radix/themes/layout/flex.py index 61e98ab6c86..0c359ca441d 100644 --- a/reflex/components/radix/themes/layout/flex.py +++ b/reflex/components/radix/themes/layout/flex.py @@ -6,10 +6,14 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + LiteralAlign, + LiteralJustify, + LiteralSpacing, + RadixThemesComponent, +) from reflex.vars.base import Var -from ..base import LiteralAlign, LiteralJustify, LiteralSpacing, RadixThemesComponent - LiteralFlexDirection = Literal["row", "column", "row-reverse", "column-reverse"] LiteralFlexWrap = Literal["nowrap", "wrap", "wrap-reverse"] diff --git a/reflex/components/radix/themes/layout/grid.py b/reflex/components/radix/themes/layout/grid.py index 24e6d8d0653..2b5147a0789 100644 --- a/reflex/components/radix/themes/layout/grid.py +++ b/reflex/components/radix/themes/layout/grid.py @@ -6,10 +6,14 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import ( + LiteralAlign, + LiteralJustify, + LiteralSpacing, + RadixThemesComponent, +) from reflex.vars.base import Var -from ..base import LiteralAlign, LiteralJustify, LiteralSpacing, RadixThemesComponent - LiteralGridFlow = Literal["row", "column", "dense", "row-dense", "column-dense"] diff --git a/reflex/components/radix/themes/layout/list.py b/reflex/components/radix/themes/layout/list.py index 2cf460638e0..b6d0593fdb8 100644 --- a/reflex/components/radix/themes/layout/list.py +++ b/reflex/components/radix/themes/layout/list.py @@ -199,4 +199,5 @@ def __getattr__(name: Any): try: return globals()[name] except KeyError: - raise AttributeError(f"module '{__name__} has no attribute '{name}'") from None + msg = f"module '{__name__} has no attribute '{name}'" + raise AttributeError(msg) from None diff --git a/reflex/components/radix/themes/layout/section.py b/reflex/components/radix/themes/layout/section.py index 68a131751e8..920c876ada4 100644 --- a/reflex/components/radix/themes/layout/section.py +++ b/reflex/components/radix/themes/layout/section.py @@ -6,10 +6,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import RadixThemesComponent from reflex.vars.base import LiteralVar, Var -from ..base import RadixThemesComponent - LiteralSectionSize = Literal["1", "2", "3"] diff --git a/reflex/components/radix/themes/layout/stack.py b/reflex/components/radix/themes/layout/stack.py index e788b6273a3..660384c12d5 100644 --- a/reflex/components/radix/themes/layout/stack.py +++ b/reflex/components/radix/themes/layout/stack.py @@ -4,9 +4,9 @@ from reflex.components.component import Component from reflex.components.core.breakpoints import Responsive +from reflex.components.radix.themes.base import LiteralAlign, LiteralSpacing from reflex.vars.base import Var -from ..base import LiteralAlign, LiteralSpacing from .flex import Flex, LiteralFlexDirection diff --git a/reflex/components/radix/themes/typography/blockquote.py b/reflex/components/radix/themes/typography/blockquote.py index e32172e005e..cc772050651 100644 --- a/reflex/components/radix/themes/typography/blockquote.py +++ b/reflex/components/radix/themes/typography/blockquote.py @@ -7,9 +7,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent from .base import LiteralTextSize, LiteralTextWeight diff --git a/reflex/components/radix/themes/typography/code.py b/reflex/components/radix/themes/typography/code.py index ab610b50535..2da681393f6 100644 --- a/reflex/components/radix/themes/typography/code.py +++ b/reflex/components/radix/themes/typography/code.py @@ -8,9 +8,13 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements from reflex.components.markdown.markdown import MarkdownComponentMap +from reflex.components.radix.themes.base import ( + LiteralAccentColor, + LiteralVariant, + RadixThemesComponent, +) from reflex.vars.base import Var -from ..base import LiteralAccentColor, LiteralVariant, RadixThemesComponent from .base import LiteralTextSize, LiteralTextWeight diff --git a/reflex/components/radix/themes/typography/heading.py b/reflex/components/radix/themes/typography/heading.py index ce1eaa68f57..6657424d476 100644 --- a/reflex/components/radix/themes/typography/heading.py +++ b/reflex/components/radix/themes/typography/heading.py @@ -8,9 +8,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements from reflex.components.markdown.markdown import MarkdownComponentMap +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent from .base import LiteralTextAlign, LiteralTextSize, LiteralTextTrim, LiteralTextWeight diff --git a/reflex/components/radix/themes/typography/link.py b/reflex/components/radix/themes/typography/link.py index 09172b10831..ecff540debd 100644 --- a/reflex/components/radix/themes/typography/link.py +++ b/reflex/components/radix/themes/typography/link.py @@ -14,10 +14,10 @@ from reflex.components.el.elements.inline import A from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.next.link import NextLink +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.utils.imports import ImportDict from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent from .base import LiteralTextSize, LiteralTextTrim, LiteralTextWeight LiteralLinkUnderline = Literal["auto", "hover", "always", "none"] @@ -86,7 +86,8 @@ def create(cls, *children, **props) -> Component: if href is not None: if not len(children): - raise ValueError("Link without a child will not display") + msg = "Link without a child will not display" + raise ValueError(msg) if "as_child" not in props: # Extract props for the NextLink, the rest go to the Link/A element. diff --git a/reflex/components/radix/themes/typography/text.py b/reflex/components/radix/themes/typography/text.py index cb6527915c7..84205326c8d 100644 --- a/reflex/components/radix/themes/typography/text.py +++ b/reflex/components/radix/themes/typography/text.py @@ -11,9 +11,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements from reflex.components.markdown.markdown import MarkdownComponentMap +from reflex.components.radix.themes.base import LiteralAccentColor, RadixThemesComponent from reflex.vars.base import Var -from ..base import LiteralAccentColor, RadixThemesComponent from .base import LiteralTextAlign, LiteralTextSize, LiteralTextTrim, LiteralTextWeight LiteralType = Literal[ diff --git a/reflex/components/react_player/audio.py b/reflex/components/react_player/audio.py index 49a5aa31ebc..2f5cc5b6d8d 100644 --- a/reflex/components/react_player/audio.py +++ b/reflex/components/react_player/audio.py @@ -5,5 +5,3 @@ class Audio(ReactPlayer): """Audio component share with Video component.""" - - pass diff --git a/reflex/components/react_player/video.py b/reflex/components/react_player/video.py index 0823bfbb564..70b513195e2 100644 --- a/reflex/components/react_player/video.py +++ b/reflex/components/react_player/video.py @@ -5,5 +5,3 @@ class Video(ReactPlayer): """Video component share with audio component.""" - - pass diff --git a/reflex/components/recharts/charts.py b/reflex/components/recharts/charts.py index dc7fe262485..7356dd453b6 100644 --- a/reflex/components/recharts/charts.py +++ b/reflex/components/recharts/charts.py @@ -64,10 +64,11 @@ def _ensure_valid_dimension(name: str, value: Any) -> None: return if isinstance(value, Var) and issubclass(value._var_type, int): return - raise ValueError( + msg = ( f"Chart {name} must be specified as int pixels or percentage, not {value!r}. " "CSS unit dimensions are allowed on parent container." ) + raise ValueError(msg) @classmethod def create(cls, *children: Any, **props: Any) -> Component: diff --git a/reflex/components/sonner/toast.py b/reflex/components/sonner/toast.py index 6c037a14a47..89b3dbf5f37 100644 --- a/reflex/components/sonner/toast.py +++ b/reflex/components/sonner/toast.py @@ -265,7 +265,8 @@ def send_toast( props.setdefault("title", message) message = "" elif message == "" and "title" not in props and "description" not in props: - raise ValueError("Toast message or title or description must be provided.") + msg = "Toast message or title or description must be provided." + raise ValueError(msg) if props: args = LiteralVar.create(ToastProps(component_name="rx.toast", **props)) # pyright: ignore [reportCallIssue] diff --git a/reflex/components/suneditor/editor.py b/reflex/components/suneditor/editor.py index 5564f32edcd..a2cf07576bf 100644 --- a/reflex/components/suneditor/editor.py +++ b/reflex/components/suneditor/editor.py @@ -259,7 +259,8 @@ def create( """ if set_options is not None: if isinstance(set_options, Var): - raise ValueError("EditorOptions cannot be a state Var") + msg = "EditorOptions cannot be a state Var" + raise ValueError(msg) props["set_options"] = { to_camel_case(k): v for k, v in set_options.dict().items() diff --git a/reflex/components/tags/iter_tag.py b/reflex/components/tags/iter_tag.py index bbb317a5a62..69423160d35 100644 --- a/reflex/components/tags/iter_tag.py +++ b/reflex/components/tags/iter_tag.py @@ -121,7 +121,8 @@ def render_component(self) -> Component: else: # If the render function takes the index as an argument. if len(args) != 2: - raise ValueError("The render function must take 2 arguments.") + msg = "The render function must take 2 arguments." + raise ValueError(msg) component = self.render_fn(arg, index) # Nested foreach components or cond must be wrapped in fragments. @@ -131,7 +132,8 @@ def render_component(self) -> Component: component = _into_component_once(component) if component is None: - raise ValueError("The render function must return a component.") + msg = "The render function must return a component." + raise ValueError(msg) # Set the component key. if component.key is None: diff --git a/reflex/config.py b/reflex/config.py index d222b150aa6..0d30944def5 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -214,12 +214,10 @@ def get_default_value_for_field(field: dataclasses.Field) -> Any: """ if field.default != dataclasses.MISSING: return field.default - elif field.default_factory != dataclasses.MISSING: + if field.default_factory != dataclasses.MISSING: return field.default_factory() - else: - raise ValueError( - f"Missing value for environment variable {field.name} and no default value found" - ) + msg = f"Missing value for environment variable {field.name} and no default value found" + raise ValueError(msg) # TODO: Change all interpret_.* signatures to value: str, field: dataclasses.Field once we migrate rx.Config to dataclasses @@ -241,9 +239,10 @@ def interpret_boolean_env(value: str, field_name: str) -> bool: if value.lower() in true_values: return True - elif value.lower() in false_values: + if value.lower() in false_values: return False - raise EnvironmentVarValueError(f"Invalid boolean value: {value} for {field_name}") + msg = f"Invalid boolean value: {value} for {field_name}" + raise EnvironmentVarValueError(msg) def interpret_int_env(value: str, field_name: str) -> int: @@ -262,9 +261,8 @@ def interpret_int_env(value: str, field_name: str) -> int: try: return int(value) except ValueError as ve: - raise EnvironmentVarValueError( - f"Invalid integer value: {value} for {field_name}" - ) from ve + msg = f"Invalid integer value: {value} for {field_name}" + raise EnvironmentVarValueError(msg) from ve def interpret_existing_path_env(value: str, field_name: str) -> ExistingPath: @@ -282,7 +280,8 @@ def interpret_existing_path_env(value: str, field_name: str) -> ExistingPath: """ path = Path(value) if not path.exists(): - raise EnvironmentVarValueError(f"Path does not exist: {path} for {field_name}") + msg = f"Path does not exist: {path} for {field_name}" + raise EnvironmentVarValueError(msg) return path @@ -316,9 +315,8 @@ def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> try: return field_type(value) except ValueError as ve: - raise EnvironmentVarValueError( - f"Invalid enum value: {value} for {field_name}" - ) from ve + msg = f"Invalid enum value: {value} for {field_name}" + raise EnvironmentVarValueError(msg) from ve def interpret_env_var_value( @@ -340,21 +338,20 @@ def interpret_env_var_value( field_type = value_inside_optional(field_type) if is_union(field_type): - raise ValueError( - f"Union types are not supported for environment variables: {field_name}." - ) + msg = f"Union types are not supported for environment variables: {field_name}." + raise ValueError(msg) if field_type is bool: return interpret_boolean_env(value, field_name) - elif field_type is str: + if field_type is str: return value - elif field_type is int: + if field_type is int: return interpret_int_env(value, field_name) - elif field_type is Path: + if field_type is Path: return interpret_path_env(value, field_name) - elif field_type is ExistingPath: + if field_type is ExistingPath: return interpret_existing_path_env(value, field_name) - elif get_origin(field_type) is list: + if get_origin(field_type) is list: return [ interpret_env_var_value( v, @@ -363,13 +360,11 @@ def interpret_env_var_value( ) for i, v in enumerate(value.split(":")) ] - elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): + if inspect.isclass(field_type) and issubclass(field_type, enum.Enum): return interpret_enum_env(value, field_type, field_name) - else: - raise ValueError( - f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex." - ) + msg = f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex." + raise ValueError(msg) T = TypeVar("T") @@ -982,9 +977,8 @@ def __init__(self, *args, **kwargs): self.state_manager_mode == constants.StateManagerMode.REDIS and not self.redis_url ): - raise ConfigError( - f"{self._prefixes[0]}REDIS_URL is required when using the redis state manager." - ) + msg = f"{self._prefixes[0]}REDIS_URL is required when using the redis state manager." + raise ConfigError(msg) @property def app_module(self) -> ModuleType | None: @@ -1008,7 +1002,7 @@ def module(self) -> str: """ if self.app_module_import is not None: return self.app_module_import - return ".".join([self.app_name, self.app_name]) + return self.app_name + "." + self.app_name def update_from_env(self) -> dict[str, Any]: """Update the config values based on set environment variables. diff --git a/reflex/constants/base.py b/reflex/constants/base.py index a0a0b415196..7d896aca2a2 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -33,11 +33,11 @@ class Dirs(SimpleNamespace): # The name of the utils file. UTILS = "utils" # The name of the state file. - STATE_PATH = "/".join([UTILS, "state"]) + STATE_PATH = UTILS + "/state" # The name of the components file. - COMPONENTS_PATH = "/".join([UTILS, "components"]) + COMPONENTS_PATH = UTILS + "/components" # The name of the contexts file. - CONTEXTS_PATH = "/".join([UTILS, "context"]) + CONTEXTS_PATH = UTILS + "/context" # The name of the output static directory. STATIC = "_static" # The name of the public html directory served at "/" diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 3a59c631221..01bb81e993d 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -176,6 +176,13 @@ class MemoizationMode: ARIA_UNDERSCORE = "aria_" ARIA_DASH = "aria-" +SPECIAL_ATTRS = ( + DATA_UNDERSCORE, + DATA_DASH, + ARIA_UNDERSCORE, + ARIA_DASH, +) + class SpecialAttributes(enum.Enum): """Special attributes for components. @@ -194,9 +201,4 @@ def is_special(cls, attr: str) -> bool: Returns: True if the attribute is special. """ - return ( - attr.startswith(DATA_UNDERSCORE) - or attr.startswith(DATA_DASH) - or attr.startswith(ARIA_UNDERSCORE) - or attr.startswith(ARIA_DASH) - ) + return attr.startswith(SPECIAL_ATTRS) diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index 9a627db33ab..914cec8aa81 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -35,7 +35,6 @@ def set_loglevel(ctx: Any, self: Any, value: str | None): @click.group def custom_components_cli(): """CLI for creating custom components.""" - pass loglevel_option = click.option( @@ -575,7 +574,7 @@ def _validate_url_with_protocol_prefix(url: str | None) -> bool: Returns: Whether the entered URL is acceptable. """ - return not url or (url.startswith("http://") or url.startswith("https://")) + return not url or (url.startswith(("http://", "https://"))) def _get_file_from_prompt_in_loop() -> tuple[bytes, str] | None: diff --git a/reflex/event.py b/reflex/event.py index dc3865d3456..615a2a9e47d 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -228,9 +228,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> EventSpec: ), Unset, ): - raise EventHandlerTypeError( - f"Event handler {self.fn.__name__} received repeated argument {repeated_arg}." - ) + msg = f"Event handler {self.fn.__name__} received repeated argument {repeated_arg}." + raise EventHandlerTypeError(msg) if not isinstance( extra_arg := next( @@ -238,9 +237,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> EventSpec: ), Unset, ): - raise EventHandlerTypeError( + msg = ( f"Event handler {self.fn.__name__} received extra argument {extra_arg}." ) + raise EventHandlerTypeError(msg) fn_args = fn_args[: len(args)] + list(kwargs) @@ -257,9 +257,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> EventSpec: try: values.append(LiteralVar.create(arg)) except TypeError as e: - raise EventHandlerTypeError( - f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}." - ) from e + msg = f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}." + raise EventHandlerTypeError(msg) from e payload = tuple(zip(fn_args, values, strict=False)) # Return the event spec. @@ -353,9 +352,8 @@ def add_args(self, *args: Var) -> EventSpec: for arg in args: values.append(LiteralVar.create(value=arg)) # noqa: PERF401, RUF100 except TypeError as e: - raise EventHandlerTypeError( - f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}." - ) from e + msg = f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}." + raise EventHandlerTypeError(msg) from e new_payload = tuple(zip(fn_args, values, strict=False)) return self.with_args(self.args + new_payload) @@ -408,7 +406,8 @@ def __call__(self, *args, **kwargs) -> EventSpec: from reflex.utils.exceptions import EventHandlerTypeError if self.fn is None: - raise EventHandlerTypeError("CallableEventSpec has no associated function.") + msg = "CallableEventSpec has no associated function." + raise EventHandlerTypeError(msg) return self.fn(*args, **kwargs) @@ -453,7 +452,7 @@ def create( if isinstance(value, Var): if isinstance(value, EventChainVar): return value - elif isinstance(value, EventVar): + if isinstance(value, EventVar): value = [value] elif safe_issubclass(value._var_type, (EventChain, EventSpec)): return cls.create( @@ -463,9 +462,8 @@ def create( **event_chain_kwargs, ) else: - raise ValueError( - f"Invalid event chain: {value!s} of type {value._var_type}" - ) + msg = f"Invalid event chain: {value!s} of type {value._var_type}" + raise ValueError(msg) elif isinstance(value, EventChain): # Trust that the caller knows what they're doing passing an EventChain directly return value @@ -485,15 +483,17 @@ def create( # Call the lambda to get the event chain. result = call_event_fn(v, args_spec, key=key) if isinstance(result, Var): - raise ValueError( + msg = ( f"Invalid event chain: {v}. Cannot use a Var-returning " "lambda inside an EventChain list." ) + raise ValueError(msg) events.extend(result) elif isinstance(v, EventVar): events.append(v) else: - raise ValueError(f"Invalid event: {v}") + msg = f"Invalid event: {v}" + raise ValueError(msg) # If the input is a callable, create an event chain. elif isinstance(value, Callable): @@ -507,7 +507,8 @@ def create( # Otherwise, raise an error. else: - raise ValueError(f"Invalid event chain: {value}") + msg = f"Invalid event chain: {value}" + raise ValueError(msg) # Add args to the event specs if necessary. events = [ @@ -783,9 +784,11 @@ def as_event_spec(self, handler: EventHandler) -> EventSpec: on_upload_progress, self.on_upload_progress_args_spec ) else: - raise ValueError(f"{on_upload_progress} is not a valid event handler.") + msg = f"{on_upload_progress} is not a valid event handler." + raise ValueError(msg) if isinstance(events, Var): - raise ValueError(f"{on_upload_progress} cannot return a var {events}.") + msg = f"{on_upload_progress} cannot return a var {events}." + raise ValueError(msg) on_upload_progress_chain = EventChain( events=[*events], args_spec=self.on_upload_progress_args_spec, @@ -1081,7 +1084,8 @@ def download( if isinstance(url, str): if not url.startswith("/"): - raise ValueError("The URL argument should start with a /") + msg = "The URL argument should start with a /" + raise ValueError(msg) # if filename is not provided, infer it from url if filename is None: @@ -1092,7 +1096,8 @@ def download( if data is not None: if url is not None: - raise ValueError("Cannot provide both URL and data to download.") + msg = "Cannot provide both URL and data to download." + raise ValueError(msg) if isinstance(data, str): # Caller provided a plain text string to download. @@ -1115,9 +1120,8 @@ def download( b64_data = b64encode(data).decode("utf-8") url = "data:application/octet-stream;base64," + b64_data else: - raise ValueError( - f"Invalid data type {type(data)} for download. Use `str` or `bytes`." - ) + msg = f"Invalid data type {type(data)} for download. Use `str` or `bytes`." + raise ValueError(msg) return server_side( "_download", @@ -1323,23 +1327,21 @@ def _check_event_args_subclass_of_callback( except TypeError as te: callback_name_context = f" of {callback_name}" if callback_name else "" key_context = f" for {key}" if key else "" - raise TypeError( - f"Could not compare types {args_types_without_vars[i]} and {callback_param_name_to_type[arg]} for argument {arg}{callback_name_context}{key_context}." - ) from te + msg = f"Could not compare types {args_types_without_vars[i]} and {callback_param_name_to_type[arg]} for argument {arg}{callback_name_context}{key_context}." + raise TypeError(msg) from te if compare_result: type_match_found[arg] = True continue - else: - type_match_found[arg] = False - as_annotated_in = ( - f" as annotated in {callback_name}" if callback_name else "" - ) - delayed_exceptions.append( - EventHandlerArgTypeMismatchError( - f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {callback_param_name_to_type[arg]}{as_annotated_in} instead." - ) + type_match_found[arg] = False + as_annotated_in = ( + f" as annotated in {callback_name}" if callback_name else "" + ) + delayed_exceptions.append( + EventHandlerArgTypeMismatchError( + f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {callback_param_name_to_type[arg]}{as_annotated_in} instead." ) + ) if all(type_match_found.values()): delayed_exceptions.clear() @@ -1495,8 +1497,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str, spec: ArgsSpe if annotation is None: if not isinstance(spec, types.LambdaType): raise MissingAnnotationError(var_name=arg_name) - else: - return dict[str, dict] + return dict[str, dict] return annotation @@ -1570,12 +1571,13 @@ def check_fn_match_arg_spec( number_of_event_args = len(parsed_event_args) if number_of_user_args - number_of_user_default_args > number_of_event_args: - raise EventFnArgMismatchError( + msg = ( f"Event {key} only provides {number_of_event_args} arguments, but " f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} " "arguments to be passed to the event handler.\n" "See https://reflex.dev/docs/events/event-arguments/" ) + raise EventFnArgMismatchError(msg) def call_event_fn( @@ -1630,9 +1632,8 @@ def call_event_fn( # Make sure the event spec is valid. if not isinstance(e, EventSpec): - raise EventHandlerValueError( - f"Lambda {fn} returned an invalid event spec: {e}." - ) + msg = f"Lambda {fn} returned an invalid event spec: {e}." + raise EventHandlerValueError(msg) # Add the event spec to the chain. events.append(e) @@ -1696,7 +1697,8 @@ def fix_events( if isinstance(e, EventHandler): e = e() if not isinstance(e, EventSpec): - raise ValueError(f"Unexpected event type, {type(e)}.") + msg = f"Unexpected event type, {type(e)}." + raise ValueError(msg) name = format.format_event_handler(e.handler) payload = {k._js_expr: v._decode() for k, v in e.args} @@ -1749,9 +1751,8 @@ def bool(self) -> NoReturn: Raises: TypeError: EventVar cannot be converted to a boolean. """ - raise TypeError( - f"Cannot convert {self._js_expr} of type {type(self).__name__} to bool." - ) + msg = f"Cannot convert {self._js_expr} of type {type(self).__name__} to bool." + raise TypeError(msg) @dataclasses.dataclass( @@ -1798,9 +1799,8 @@ def no_args(): try: value = call_event_handler(value, no_args) except EventFnArgMismatchError: - raise EventFnArgMismatchError( - f"Event handler {value.fn.__qualname__} used inside of a rx.cond() must not take any arguments." - ) from None + msg = f"Event handler {value.fn.__qualname__} used inside of a rx.cond() must not take any arguments." + raise EventFnArgMismatchError(msg) from None return cls( _js_expr="", @@ -1835,9 +1835,8 @@ def bool(self) -> NoReturn: Raises: TypeError: EventChainVar cannot be converted to a boolean. """ - raise TypeError( - f"Cannot convert {self._js_expr} of type {type(self).__name__} to bool." - ) + msg = f"Cannot convert {self._js_expr} of type {type(self).__name__} to bool." + raise TypeError(msg) @dataclasses.dataclass( @@ -1906,9 +1905,8 @@ def create( invocation = value.invocation if invocation is not None and not isinstance(invocation, FunctionVar): - raise ValueError( - f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}." - ) + msg = f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}." + raise ValueError(msg) return cls( _js_expr="", @@ -2142,12 +2140,12 @@ def wrapper( if not inspect.iscoroutinefunction( func ) and not inspect.isasyncgenfunction(func): - raise TypeError( - "Background task must be async function or generator." - ) + msg = "Background task must be async function or generator." + raise TypeError(msg) setattr(func, BACKGROUND_TASK_MARKER, True) if getattr(func, "__name__", "").startswith("_"): - raise ValueError("Event handlers cannot be private.") + msg = "Event handlers cannot be private." + raise ValueError(msg) qualname: str | None = getattr(func, "__qualname__", None) diff --git a/reflex/experimental/__init__.py b/reflex/experimental/__init__.py index c3a9ecaaa41..c755bd50300 100644 --- a/reflex/experimental/__init__.py +++ b/reflex/experimental/__init__.py @@ -6,9 +6,9 @@ from reflex.components.props import PropsBase from reflex.components.radix.themes.components.progress import progress as progress from reflex.components.sonner.toast import toast as toast +from reflex.utils.console import warn +from reflex.utils.misc import run_in_thread -from ..utils.console import warn -from ..utils.misc import run_in_thread from . import hooks as hooks from .client_state import ClientStateVar as ClientStateVar from .layout import layout as layout diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index 21fbcba3aea..631b2025607 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -111,7 +111,8 @@ def create( var_name = get_unique_variable_name() id_name = "id_" + get_unique_variable_name() if not isinstance(var_name, str): - raise ValueError("var_name must be a string.") + msg = "var_name must be a string." + raise ValueError(msg) if default is NoValue: default_var = Var(_js_expr="") elif not isinstance(default, Var): @@ -271,7 +272,8 @@ def retrieve(self, callback: EventHandler | Callable | None = None) -> EventSpec ValueError: If the ClientStateVar is not global. """ if not self._global_ref: - raise ValueError("ClientStateVar must be global to retrieve the value.") + msg = "ClientStateVar must be global to retrieve the value." + raise ValueError(msg) return run_script(_client_state_ref(self._getter_name), callback=callback) def push(self, value: Any) -> EventSpec: @@ -289,6 +291,7 @@ def push(self, value: Any) -> EventSpec: ValueError: If the ClientStateVar is not global. """ if not self._global_ref: - raise ValueError("ClientStateVar must be global to push the value.") + msg = "ClientStateVar must be global to push the value." + raise ValueError(msg) value = Var.create(value) return run_script(f"{_client_state_ref(self._setter_name)}({value})") diff --git a/reflex/istate/manager.py b/reflex/istate/manager.py index de1044ce95c..dfe10cef333 100644 --- a/reflex/istate/manager.py +++ b/reflex/istate/manager.py @@ -66,9 +66,8 @@ def create(cls, state: type[BaseState]): lock_expiration=config.redis_lock_expiration, lock_warning_threshold=config.redis_lock_warning_threshold, ) - raise InvalidStateManagerModeError( - f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" - ) + msg = f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" + raise InvalidStateManagerModeError(msg) @abstractmethod async def get_state(self, token: str) -> BaseState: @@ -80,7 +79,6 @@ async def get_state(self, token: str) -> BaseState: Returns: The state for the token. """ - pass @abstractmethod async def set_state(self, token: str, state: BaseState): @@ -90,7 +88,6 @@ async def set_state(self, token: str, state: BaseState): token: The token to set the state for. state: The state to set. """ - pass @abstractmethod @contextlib.asynccontextmanager @@ -145,7 +142,6 @@ async def set_state(self, token: str, state: BaseState): token: The token to set the state for. state: The state to set. """ - pass @override @contextlib.asynccontextmanager @@ -269,6 +265,7 @@ async def load_state(self, token: str) -> BaseState | None: return BaseState._deserialize(fp=file) except Exception: pass + return None async def populate_substates( self, client_token: str, state: BaseState, root_state: BaseState @@ -449,9 +446,8 @@ def __post_init__(self): InvalidLockWarningThresholdError: If the lock warning threshold is invalid. """ if self.lock_warning_threshold >= (lock_expiration := self.lock_expiration): - raise InvalidLockWarningThresholdError( - f"The lock warning threshold({self.lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})." - ) + msg = f"The lock warning threshold({self.lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})." + raise InvalidLockWarningThresholdError(msg) def _get_required_state_classes( self, @@ -557,9 +553,8 @@ async def get_state( # Get the State class associated with the given path. state_cls = self.state.get_class_substate(state_path) else: - raise RuntimeError( - f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" - ) + msg = f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" + raise RuntimeError(msg) # Determine which states we already have. flat_state_tree: dict[str, BaseState] = ( @@ -601,11 +596,12 @@ async def get_state( ) parent_state = flat_state_tree.get(parent_state_name) if parent_state is None: - raise RuntimeError( + msg = ( f"Parent state for {state.get_full_name()} was not found " "in the state tree, but should have already been fetched. " - "This is a bug", + "This is a bug" ) + raise RuntimeError(msg) parent_state.substates[state_name] = state state.parent_state = parent_state @@ -638,12 +634,13 @@ async def set_state( lock_id is not None and await self.redis.get(self._lock_key(token)) != lock_id ): - raise LockExpiredError( + msg = ( f"Lock expired for token {token} while processing. Consider increasing " f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " "or use `@rx.event(background=True)` decorator for long-running tasks." ) - elif lock_id is not None: + raise LockExpiredError(msg) + if lock_id is not None: time_taken = self.lock_expiration / 1000 - ( await self.redis.ttl(self._lock_key(token)) ) @@ -657,9 +654,8 @@ async def set_state( client_token, substate_name = _split_substate_key(token) # If the substate name on the token doesn't match the instance name, it cannot have a parent. if state.parent_state is not None and state.get_full_name() != substate_name: - raise RuntimeError( - f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." - ) + msg = f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." + raise RuntimeError(msg) # Recursively set_state on all known substates. tasks = [ diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index e6a4ae62665..c54e924d8c8 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -122,9 +122,8 @@ async def __aenter__(self) -> StateProxy: self._self_actx_lock.locked() and current_task == self._self_actx_lock_holder ): - raise ImmutableStateError( - "The state is already mutable. Do not nest `async with self` blocks." - ) + msg = "The state is already mutable. Do not nest `async with self` blocks." + raise ImmutableStateError(msg) from reflex.state import _substate_key @@ -173,7 +172,8 @@ def __enter__(self): Raises: TypeError: always, because only async contextmanager protocol is supported. """ - raise TypeError("Background task must use `async with self` to modify state.") + msg = "Background task must use `async with self` to modify state." + raise TypeError(msg) def __exit__(self, *exc_info: Any) -> None: """Exit the regular context manager protocol. @@ -181,7 +181,6 @@ def __exit__(self, *exc_info: Any) -> None: Args: exc_info: The exception info tuple. """ - pass def __getattr__(self, name: str) -> Any: """Get the attribute from the underlying state instance. @@ -196,10 +195,11 @@ def __getattr__(self, name: str) -> Any: ImmutableStateError: If the state is not in mutable mode. """ if name in ["substates", "parent_state"] and not self._is_mutable(): - raise ImmutableStateError( + msg = ( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." ) + raise ImmutableStateError(msg) value = super().__getattr__(name) if not name.startswith("_self_") and isinstance(value, MutableProxy): @@ -243,10 +243,11 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) return - raise ImmutableStateError( + msg = ( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." ) + raise ImmutableStateError(msg) def get_substate(self, path: Sequence[str]) -> BaseState: """Only allow substate access with lock held. @@ -261,10 +262,11 @@ def get_substate(self, path: Sequence[str]) -> BaseState: ImmutableStateError: If the state is not in mutable mode. """ if not self._is_mutable(): - raise ImmutableStateError( + msg = ( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." ) + raise ImmutableStateError(msg) return self.__wrapped__.get_substate(path) async def get_state(self, state_cls: type[BaseState]) -> BaseState: @@ -280,10 +282,11 @@ async def get_state(self, state_cls: type[BaseState]) -> BaseState: ImmutableStateError: If the state is not in mutable mode. """ if not self._is_mutable(): - raise ImmutableStateError( + msg = ( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." ) + raise ImmutableStateError(msg) return type(self)( await self.__wrapped__.get_state(state_cls), parent_state_proxy=self ) @@ -323,7 +326,8 @@ def __setattr__(self, name: str, value: Any) -> None: # Special case attributes of the proxy itself, not applied to the wrapped object. super().__setattr__(name, value) return - raise NotImplementedError("This is a read-only state proxy.") + msg = "This is a read-only state proxy." + raise NotImplementedError(msg) def mark_dirty(self): """Mark the state as dirty. @@ -331,7 +335,8 @@ def mark_dirty(self): Raises: NotImplementedError: Always raised when trying to mark the proxied state as dirty. """ - raise NotImplementedError("This is a read-only state proxy.") + msg = "This is a read-only state proxy." + raise NotImplementedError(msg) class MutableProxy(wrapt.ObjectProxy): @@ -460,6 +465,7 @@ def _mark_dirty( self._self_state._mark_dirty() if wrapped is not None: return wrapped(*args, **(kwargs or {})) + return None @classmethod def _is_mutable_type(cls, value: Any) -> bool: @@ -748,10 +754,11 @@ def _mark_dirty( ImmutableStateError: if the StateProxy is not mutable. """ if not self._self_state._is_mutable(): - raise ImmutableStateError( + msg = ( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." ) + raise ImmutableStateError(msg) return super()._mark_dirty( wrapped=wrapped, instance=instance, args=args, kwargs=kwargs ) diff --git a/reflex/model.py b/reflex/model.py index 2add4cbcc67..6c82b199a22 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -83,7 +83,8 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: conf = get_config() url = url or conf.db_url if url is None: - raise ValueError("No database url configured") + msg = "No database url configured" + raise ValueError(msg) global _ENGINE if url in _ENGINE: @@ -125,7 +126,8 @@ def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine: f"db_url `{_safe_db_url_for_logging(conf.db_url)}`." ) if url is None: - raise ValueError("No async database url configured") + msg = "No async database url configured" + raise ValueError(msg) global _ASYNC_ENGINE if url in _ASYNC_ENGINE: @@ -271,7 +273,7 @@ def _dict_recursive(cls, value: Any): """ if hasattr(value, "dict"): return value.dict() - elif isinstance(value, list): + if isinstance(value, list): return [cls._dict_recursive(item) for item in value] return value @@ -481,7 +483,7 @@ def migrate(cls, autogenerate: bool = False) -> bool | None: None - indicating the process was skipped. """ if not environment.ALEMBIC_CONFIG.get().exists(): - return + return None with cls.get_db_engine().connect() as connection: cls._alembic_upgrade(connection=connection) diff --git a/reflex/reflex.py b/reflex/reflex.py index 4e109739d8d..9f4f543e289 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -36,7 +36,6 @@ def set_loglevel(ctx: click.Context, self: click.Parameter, value: str | None): @click.version_option(constants.Reflex.VERSION, message="%(version)s") def cli(): """Reflex CLI to create, run, and deploy apps.""" - pass loglevel_option = click.option( @@ -235,7 +234,8 @@ def _run( exec.run_backend_prod, ) if not setup_frontend or not frontend_cmd or not backend_cmd: - raise ValueError(f"Invalid env: {env}. Must be DEV or PROD.") + msg = f"Invalid env: {env}. Must be DEV or PROD." + raise ValueError(msg) # Post a telemetry event. telemetry.send(f"run-{env.value}") @@ -481,13 +481,11 @@ def logout(): @click.group def db_cli(): """Subcommands for managing the database schema.""" - pass @click.group def script_cli(): """Subcommands for running helper scripts.""" - pass def _skip_compile(): diff --git a/reflex/route.py b/reflex/route.py index 3f49f66e90d..16f6c4d6b9e 100644 --- a/reflex/route.py +++ b/reflex/route.py @@ -18,11 +18,13 @@ def verify_route_validity(route: str) -> None: """ pattern = catchall_in_route(route) if pattern and not route.endswith(pattern): - raise ValueError(f"Catch-all must be the last part of the URL: {route}") + msg = f"Catch-all must be the last part of the URL: {route}" + raise ValueError(msg) if route == "api" or route.startswith("api/"): - raise ValueError( + msg = ( f"Cannot have a route prefixed with 'api/': {route} (conflicts with NextJS)" ) + raise ValueError(msg) def get_route_args(route: str) -> dict[str, str]: @@ -48,9 +50,8 @@ def add_route_arg(match: re.Match[str], type_: str): """ arg_name = match.groups()[0] if arg_name in args: - raise ValueError( - f"Arg name [{arg_name}] is used more than once in this URL" - ) + msg = f"Arg name [{arg_name}] is used more than once in this URL" + raise ValueError(msg) args[arg_name] = type_ # Regex to check for route args. @@ -136,7 +137,4 @@ def replace_brackets_with_keywords(input_string: str) -> str: r"\[\[.+?\]\]", constants.RouteRegex.DOUBLE_SEGMENT, output_string ) # Replace [] with __SINGLE_SEGMENT__ - output_string = re.sub( - r"\[.+?\]", constants.RouteRegex.SINGLE_SEGMENT, output_string - ) - return output_string + return re.sub(r"\[.+?\]", constants.RouteRegex.SINGLE_SEGMENT, output_string) diff --git a/reflex/state.py b/reflex/state.py index 5ab6a0bb9de..e2c516e390a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -148,7 +148,8 @@ async def _no_chain_background_task_gen(*args, **kwargs): return _no_chain_background_task_gen - raise TypeError(f"{fn} is marked as a background task, but is not async.") + msg = f"{fn} is marked as a background task, but is not async." + raise TypeError(msg) def _substate_key( @@ -233,22 +234,19 @@ def __call__(self, *args: Any) -> EventSpec: if args: if not isinstance(args[0], str): - raise EventHandlerValueError( - f"Var name must be passed as a string, got {args[0]!r}" - ) + msg = f"Var name must be passed as a string, got {args[0]!r}" + raise EventHandlerValueError(msg) handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) # Check that the requested Var setter exists on the State at compile time. if handler is None: - raise AttributeError( - f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" - ) + msg = f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" + raise AttributeError(msg) if asyncio.iscoroutinefunction(handler.fn): - raise NotImplementedError( - f"Setter for {args[0]} is async, which is not supported." - ) + msg = f"Setter for {args[0]} is async, which is not supported." + raise NotImplementedError(msg) return super().__call__(*args) @@ -406,14 +404,14 @@ def __init__( from reflex.utils.exceptions import ReflexRuntimeError if not _reflex_internal_init and not is_testing_env(): - raise ReflexRuntimeError( + msg = ( "State classes should not be instantiated directly in a Reflex app. " "See https://reflex.dev/docs/state/ for further information." ) + raise ReflexRuntimeError(msg) if type(self)._mixin: - raise ReflexRuntimeError( - f"{type(self).__name__} is a state mixin and cannot be instantiated directly." - ) + msg = f"{type(self).__name__} is a state mixin and cannot be instantiated directly." + raise ReflexRuntimeError(msg) kwargs["parent_state"] = parent_state super().__init__() for name, value in kwargs.items(): @@ -462,10 +460,11 @@ def _validate_module_name(cls) -> None: NameError: If the module name is invalid. """ if "___" in cls.__module__: - raise NameError( + msg = ( "The module name of a State class cannot contain '___'. " "Please rename the module." ) + raise NameError(msg) @classmethod def __init_subclass__(cls, mixin: bool = False, **kwargs): @@ -515,10 +514,11 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): # Check if another substate class with the same name has already been defined. if cls.get_name() in {c.get_name() for c in parent_state.class_subclasses}: # This should not happen, since we have added module prefix to state names in #3214 - raise StateValueError( + msg = ( f"The substate class '{cls.get_name()}' has been defined multiple times. " "Shadowing substate classes is not allowed." ) + raise StateValueError(msg) # Track this new subclass in the parent state's subclasses set. parent_state.class_subclasses.add(cls) @@ -832,9 +832,8 @@ def _check_overridden_methods(cls): overridden_methods.add(method.__name__) for method_name in overridden_methods: - raise EventHandlerShadowsBuiltInStateMethodError( - f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead" - ) + msg = f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead" + raise EventHandlerShadowsBuiltInStateMethodError(msg) @classmethod def _check_overridden_basevars(cls): @@ -845,9 +844,8 @@ def _check_overridden_basevars(cls): """ for computed_var_ in cls._get_computed_vars(): if computed_var_._js_expr in cls.__annotations__: - raise ComputedVarShadowsBaseVarsError( - f"The computed var name `{computed_var_._js_expr}` shadows a base var in {cls.__module__}.{cls.__name__}; use a different name instead" - ) + msg = f"The computed var name `{computed_var_._js_expr}` shadows a base var in {cls.__module__}.{cls.__name__}; use a different name instead" + raise ComputedVarShadowsBaseVarsError(msg) @classmethod def _check_overridden_computed_vars(cls) -> None: @@ -861,9 +859,8 @@ def _check_overridden_computed_vars(cls) -> None: continue name = cv._js_expr if name in cls.inherited_vars or name in cls.inherited_backend_vars: - raise ComputedVarShadowsStateVarError( - f"The computed var name `{cv._js_expr}` shadows a var in {cls.__module__}.{cls.__name__}; use a different name instead" - ) + msg = f"The computed var name `{cv._js_expr}` shadows a var in {cls.__module__}.{cls.__name__}; use a different name instead" + raise ComputedVarShadowsStateVarError(msg) @classmethod def get_skip_vars(cls) -> set[str]: @@ -901,7 +898,8 @@ def get_parent_state(cls) -> type[BaseState] | None: if issubclass(base, BaseState) and base is not BaseState and not base._mixin ] if len(parent_states) >= 2: - raise ValueError(f"Only one parent state is allowed {parent_states}.") + msg = f"Only one parent state is allowed {parent_states}." + raise ValueError(msg) # The first non-mixin state in the mro is our parent. for base in cls.mro()[1:]: if not issubclass(base, BaseState) or base._mixin: @@ -953,7 +951,7 @@ def get_full_name(cls) -> str: name = cls.get_name() parent_state = cls.get_parent_state() if parent_state is not None: - name = ".".join((parent_state.get_full_name(), name)) + name = parent_state.get_full_name() + "." + name return name @classmethod @@ -982,7 +980,8 @@ def get_class_substate(cls, path: Sequence[str] | str) -> type[BaseState]: for substate in cls.get_substates(): if path[0] == substate.get_name(): return substate.get_class_substate(path[1:]) - raise ValueError(f"Invalid path: {path}") + msg = f"Invalid path: {path}" + raise ValueError(msg) @classmethod def get_class_var(cls, path: Sequence[str]) -> Any: @@ -1000,7 +999,8 @@ def get_class_var(cls, path: Sequence[str]) -> Any: path, name = path[:-1], path[-1] substate = cls.get_class_substate(tuple(path)) if not hasattr(substate, name): - raise ValueError(f"Invalid path: {path}") + msg = f"Invalid path: {path}" + raise ValueError(msg) return getattr(substate, name) @classmethod @@ -1029,12 +1029,13 @@ def _init_var(cls, prop: Var): from reflex.utils.exceptions import VarTypeError if not types.is_valid_var_type(prop._var_type): - raise VarTypeError( + msg = ( "State vars must be of a serializable type. " "Valid types include strings, numbers, booleans, lists, " "dictionaries, dataclasses, datetime objects, and pydantic models. " f'Found var "{prop._js_expr}" with type {prop._var_type}.' ) + raise VarTypeError(msg) cls._set_var(prop) if cls.is_user_defined() and get_config().state_auto_setters: cls._create_setter(prop) @@ -1056,9 +1057,8 @@ def add_var(cls, name: str, type_: Any, default_value: Any = None): NameError: if a variable of this name already exists """ if name in cls.__fields__: - raise NameError( - f"The variable '{name}' already exist. Use a different name" - ) + msg = f"The variable '{name}' already exist. Use a different name" + raise NameError(msg) # create the variable based on name and type var = Var( @@ -1260,9 +1260,8 @@ def _check_overwritten_dynamic_args(cls, args: list[str]): arg in cls.computed_vars and not isinstance(cls.computed_vars[arg], DynamicRouteVar) ) or arg in cls.base_vars: - raise DynamicRouteArgShadowsStateVarError( - f"Dynamic route arg '{arg}' is shadowing an existing var in {cls.__module__}.{cls.__name__}" - ) + msg = f"Dynamic route arg '{arg}' is shadowing an existing var in {cls.__module__}.{cls.__name__}" + raise DynamicRouteArgShadowsStateVarError(msg) for substate in cls.get_substates(): substate._check_overwritten_dynamic_args(args) @@ -1363,10 +1362,11 @@ def __setattr__(self, name: str, value: Any): f"_{getattr(type(self), '__original_name__', type(self).__name__)}__" ) ): - raise SetUndefinedStateVarError( + msg = ( f"The state variable '{name}' has not been defined in '{type(self).__name__}'. " f"All state variables must be declared before they can be set." ) + raise SetUndefinedStateVarError(msg) fields = self.get_fields() @@ -1470,7 +1470,8 @@ def get_substate(self, path: Sequence[str]) -> BaseState: return self path = path[1:] if path[0] not in self.substates: - raise ValueError(f"Invalid path: {path}") + msg = f"Invalid path: {path}" + raise ValueError(msg) return self.substates[path[0]].get_substate(path[1:]) @classmethod @@ -1517,10 +1518,11 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: # Then get the target state and all its substates. state_manager = get_state_manager() if not isinstance(state_manager, StateManagerRedis): - raise RuntimeError( + msg = ( f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " - "(All states should already be available -- this is likely a bug).", + "(All states should already be available -- this is likely a bug)." ) + raise RuntimeError(msg) state_in_redis = await state_manager.get_state( token=_substate_key(self.router.session.client_token, state_cls), top_level=False, @@ -1528,9 +1530,8 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: ) if not isinstance(state_in_redis, state_cls): - raise StateMismatchError( - f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." - ) + msg = f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." + raise StateMismatchError(msg) return state_in_redis @@ -1549,9 +1550,10 @@ def _get_state_from_cache(self, state_cls: type[T_STATE]) -> T_STATE: root_state = self._get_root_state() substate = root_state.get_substate(state_cls.get_full_name().split(".")) if not isinstance(substate, state_cls): - raise StateMismatchError( + msg = ( f"Searched for state {state_cls.get_full_name()} but found {substate}." ) + raise StateMismatchError(msg) return substate async def get_state(self, state_cls: type[T_STATE]) -> T_STATE: @@ -1601,9 +1603,8 @@ async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: var_data = var._get_all_var_data() if var_data is None or not var_data.state: - raise UnretrievableVarValueError( - f"Unable to retrieve value for {var._js_expr}: not associated with any state." - ) + msg = f"Unable to retrieve value for {var._js_expr}: not associated with any state." + raise UnretrievableVarValueError(msg) # Fastish case: this var belongs to this state if var_data.state == self.get_full_name(): return getattr(self, var_data.field_name) @@ -1634,9 +1635,8 @@ def _get_event_handler( path, name = path[:-1], path[-1] substate = self.get_substate(path) if not substate: - raise ValueError( - "The value of state cannot be None when processing an event." - ) + msg = "The value of state cannot be None when processing an event." + raise ValueError(msg) handler = substate.event_handlers[name] # For background tasks, proxy the state @@ -1702,10 +1702,11 @@ def _is_valid_type(events: Any) -> bool: "ignore", message=f"coroutine '{coroutine_name}' was never awaited" ) - raise TypeError( + msg = ( f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (i.e. using `type(self)` or other class references)." f" Returned events of types {', '.join(map(str, map(type, events)))!s}." ) + raise TypeError(msg) async def _as_state_update( self, @@ -1834,9 +1835,8 @@ async def _process_event( try: payload[arg] = hinted_args(value) except ValueError: - raise ValueError( - f"Received a string value ({value}) for {arg} but expected a {hinted_args}" - ) from None + msg = f"Received a string value ({value}) for {arg} but expected a {hinted_args}" + raise ValueError(msg) from None else: console.warn( f"Received a string value ({value}) for {arg} but expected a {hinted_args}. A simple conversion was successful." @@ -2123,9 +2123,8 @@ async def __aenter__(self) -> BaseState: Raises: TypeError: always, because async contextmanager protocol is only supported for background task. """ - raise TypeError( - "Only background task should use `async with self` to modify state." - ) + msg = "Only background task should use `async with self` to modify state." + raise TypeError(msg) async def __aexit__(self, *exc_info: Any) -> None: """Exit the async context manager protocol. @@ -2136,7 +2135,6 @@ async def __aexit__(self, *exc_info: Any) -> None: Args: exc_info: The exception info tuple. """ - pass def __getstate__(self): """Get the state for redis serialization. @@ -2298,9 +2296,10 @@ def _deserialize( elif fp is not None and data is None: (substate_schema, state) = pickle.load(fp) else: - raise ValueError("Only one of `data` or `fp` must be provided") + msg = "Only one of `data` or `fp` must be provided" + raise ValueError(msg) if substate_schema != state._to_schema(): - raise StateSchemaMismatchError() + raise StateSchemaMismatchError return state @@ -2377,14 +2376,12 @@ def dynamic(func: Callable[[T], Component]): values = list(func_signature.values()) if number_of_parameters != 1: - raise DynamicComponentInvalidSignatureError( - "The function must have exactly one parameter, which is the state class." - ) + msg = "The function must have exactly one parameter, which is the state class." + raise DynamicComponentInvalidSignatureError(msg) if len(values) != 1: - raise DynamicComponentInvalidSignatureError( - "You must provide a type hint for the state class in the function." - ) + msg = "You must provide a type hint for the state class in the function." + raise DynamicComponentInvalidSignatureError(msg) state_class: type[T] = values[0] @@ -2457,7 +2454,7 @@ def on_load_internal(self) -> list[Event | EventSpec | event.EventCallback] | No ) if not load_events: self.is_hydrated = True - return # Fast path for navigation with no on_load events defined. + return None # Fast path for navigation with no on_load events defined. self.is_hydrated = False return [ *fix_events( @@ -2549,9 +2546,8 @@ def get_component(cls, *children, **props) -> Component: Raises: NotImplementedError: if the subclass does not override this method. """ - raise NotImplementedError( - f"{cls.__name__} must implement get_component to return the component instance." - ) + msg = f"{cls.__name__} must implement get_component to return the component instance." + raise NotImplementedError(msg) @classmethod def create(cls, *children, **props) -> Component: diff --git a/reflex/style.py b/reflex/style.py index d5cb8ffa7bd..5be7473d762 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -121,10 +121,11 @@ def convert_item( ReflexError: If an EventHandler is used as a style value """ if isinstance(style_item, EventHandler): - raise ReflexError( + msg = ( "EventHandlers cannot be used as style values. " "Please use a Var or a literal value." ) + raise ReflexError(msg) if isinstance(style_item, Var): return style_item, style_item._get_all_var_data() @@ -381,6 +382,7 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None: if _var_data is not None: emotion_style._var_data = VarData.merge(emotion_style._var_data, _var_data) return emotion_style + return None def convert_dict_to_style_and_format_emotion( diff --git a/reflex/testing.py b/reflex/testing.py index b4a1d39890b..ac8dd0bbf04 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -155,9 +155,8 @@ def create( app_name = f"{func_name}_{slug_suffix}" app_name = re.sub(r"[^a-zA-Z0-9_]", "_", app_name) elif isinstance(app_source, str): - raise ValueError( - "app_name must be provided when app_source is a string." - ) + msg = "app_name must be provided when app_source is a string." + raise ValueError(msg) else: app_name = app_source.__name__ @@ -285,7 +284,8 @@ def _initialize_app(self): self.app_instance._state_manager, StateManagerRedis ): if self.app_instance._state is None: - raise RuntimeError("State is not set.") + msg = "State is not set." + raise RuntimeError(msg) # Create our own redis connection for testing. self.state_manager = StateManagerRedis.create(self.app_instance._state) else: @@ -299,7 +299,8 @@ def _reload_state_module(self): def _get_backend_shutdown_handler(self): if self.backend is None: - raise RuntimeError("Backend was not initialized.") + msg = "Backend was not initialized." + raise RuntimeError(msg) original_shutdown = self.backend.shutdown @@ -330,7 +331,8 @@ async def _shutdown(*args, **kwargs) -> None: def _start_backend(self, port: int = 0): if self.app_asgi is None: - raise RuntimeError("App was not initialized.") + msg = "App was not initialized." + raise RuntimeError(msg) self.backend = uvicorn.Server( uvicorn.Config( app=self.app_asgi, @@ -366,7 +368,8 @@ async def _reset_backend_state_manager(self): state=self.app_instance._state, ) if not isinstance(self.app_instance.state_manager, StateManagerRedis): - raise RuntimeError("Failed to reset state manager.") + msg = "Failed to reset state manager." + raise RuntimeError(msg) def _start_frontend(self): # Set up the frontend. @@ -406,7 +409,8 @@ def _wait_frontend(self): config.deploy_url = self.frontend_url break if self.frontend_url is None: - raise RuntimeError("Frontend did not start") + msg = "Frontend did not start" + raise RuntimeError(msg) def consume_frontend_output(): while True: @@ -578,20 +582,23 @@ def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket: TimeoutError: when server or sockets are not ready """ if self.backend is None: - raise RuntimeError("Backend is not running.") + msg = "Backend is not running." + raise RuntimeError(msg) backend = self.backend # check for servers to be initialized if not self._poll_for( target=lambda: getattr(backend, "servers", False), timeout=timeout, ): - raise TimeoutError("Backend servers are not initialized.") + msg = "Backend servers are not initialized." + raise TimeoutError(msg) # check for sockets to be listening if not self._poll_for( target=lambda: getattr(backend.servers[0], "sockets", False), timeout=timeout, ): - raise TimeoutError("Backend is not listening.") + msg = "Backend is not listening." + raise TimeoutError(msg) return backend.servers[0].sockets[0] def frontend( @@ -619,12 +626,14 @@ def frontend( RuntimeError: when selenium is not importable or frontend is not running """ if not has_selenium: - raise RuntimeError( + msg = ( "Frontend functionality requires `selenium` to be installed, " "and it could not be imported." ) + raise RuntimeError(msg) if self.frontend_url is None: - raise RuntimeError("Frontend is not running.") + msg = "Frontend is not running." + raise RuntimeError(msg) want_headless = False if environment.APP_HARNESS_HEADLESS.get(): want_headless = True @@ -650,7 +659,8 @@ def frontend( if want_headless: driver_options.add_argument("headless") if driver_options is None: - raise RuntimeError(f"Could not determine options for {driver_clz}") + msg = f"Could not determine options for {driver_clz}" + raise RuntimeError(msg) if args := environment.APP_HARNESS_DRIVER_ARGS.get(): for arg in args.split(","): driver_options.add_argument(arg) @@ -680,7 +690,8 @@ async def get_state(self, token: str) -> BaseState: RuntimeError: when the app hasn't started running """ if self.state_manager is None: - raise RuntimeError("state_manager is not set.") + msg = "state_manager is not set." + raise RuntimeError(msg) try: return await self.state_manager.get_state(token) finally: @@ -698,7 +709,8 @@ async def set_state(self, token: str, **kwargs) -> None: RuntimeError: when the app hasn't started running """ if self.state_manager is None: - raise RuntimeError("state_manager is not set.") + msg = "state_manager is not set." + raise RuntimeError(msg) state = await self.get_state(token) for key, value in kwargs.items(): setattr(state, key, value) @@ -722,9 +734,11 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: RuntimeError: when the app hasn't started running """ if self.state_manager is None: - raise RuntimeError("state_manager is not set.") + msg = "state_manager is not set." + raise RuntimeError(msg) if self.app_instance is None: - raise RuntimeError("App is not running.") + msg = "App is not running." + raise RuntimeError(msg) app_state_manager = self.app_instance.state_manager if isinstance(self.state_manager, StateManagerRedis): # Temporarily replace the app's state manager with our own, since @@ -761,9 +775,8 @@ def poll_for_content( target=lambda: element.text != exp_not_equal, timeout=timeout, ): - raise TimeoutError( - f"{element} content remains {exp_not_equal!r} while polling.", - ) + msg = f"{element} content remains {exp_not_equal!r} while polling." + raise TimeoutError(msg) return element.text def poll_for_value( @@ -792,9 +805,8 @@ def poll_for_value( target=lambda: element.get_attribute("value") not in exp_not_equal, timeout=timeout, ): - raise TimeoutError( - f"{element} content remains {exp_not_equal!r} while polling.", - ) + msg = f"{element} content remains {exp_not_equal!r} while polling." + raise TimeoutError(msg) return element.get_attribute("value") def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]: @@ -812,15 +824,18 @@ def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]: ValueError: when the state_manager is not a memory state manager """ if self.app_instance is None: - raise RuntimeError("App is not running.") + msg = "App is not running." + raise RuntimeError(msg) state_manager = self.app_instance.state_manager if not isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): - raise ValueError("Only works with memory or disk state manager") + msg = "Only works with memory or disk state manager" + raise ValueError(msg) if not self._poll_for( target=lambda: state_manager.states, timeout=timeout, ): - raise TimeoutError("No states were observed while polling.") + msg = "No states were observed while polling." + raise TimeoutError(msg) return state_manager.states @@ -962,11 +977,13 @@ def _start_frontend(self): def _wait_frontend(self): self._poll_for(lambda: self.frontend_server is not None) if self.frontend_server is None or not self.frontend_server.socket.fileno(): - raise RuntimeError("Frontend did not start") + msg = "Frontend did not start" + raise RuntimeError(msg) def _start_backend(self): if self.app_asgi is None: - raise RuntimeError("App was not initialized.") + msg = "App was not initialized." + raise RuntimeError(msg) environment.REFLEX_SKIP_COMPILE.set(True) self.backend = uvicorn.Server( uvicorn.Config( diff --git a/reflex/utils/build.py b/reflex/utils/build.py index c6d946d8232..a2a55f914a0 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -132,7 +132,7 @@ def _zip( def zip_app( frontend: bool = True, backend: bool = True, - zip_dest_dir: str | Path = Path.cwd(), + zip_dest_dir: str | Path | None = None, upload_db_file: bool = False, ): """Zip up the app. @@ -143,6 +143,7 @@ def zip_app( zip_dest_dir: The directory to export the zip file to. upload_db_file: Whether to upload the database file. """ + zip_dest_dir = zip_dest_dir or Path.cwd() zip_dest_dir = Path(zip_dest_dir) files_to_exclude = { constants.ComponentName.FRONTEND.zip(), diff --git a/reflex/utils/console.py b/reflex/utils/console.py index 40cb8a1b2c2..1167a44ace5 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -59,9 +59,8 @@ def set_log_level(log_level: LogLevel | None): if log_level is None: return if not isinstance(log_level, LogLevel): - raise TypeError( - f"log_level must be a LogLevel enum value, got {log_level} of type {type(log_level)} instead." - ) + msg = f"log_level must be a LogLevel enum value, got {log_level} of type {type(log_level)} instead." + raise TypeError(msg) global _LOG_LEVEL if log_level != _LOG_LEVEL: # Set the loglevel persistenly for subprocesses. @@ -89,8 +88,7 @@ def print(msg: str, dedupe: bool = False, **kwargs): if dedupe: if msg in _EMITTED_PRINTS: return - else: - _EMITTED_PRINTS.add(msg) + _EMITTED_PRINTS.add(msg) _console.print(msg, **kwargs) @@ -107,8 +105,7 @@ def debug(msg: str, dedupe: bool = False, **kwargs): if dedupe: if msg_ in _EMITTED_DEBUG: return - else: - _EMITTED_DEBUG.add(msg_) + _EMITTED_DEBUG.add(msg_) if progress := kwargs.pop("progress", None): progress.console.print(msg_, **kwargs) else: @@ -127,8 +124,7 @@ def info(msg: str, dedupe: bool = False, **kwargs): if dedupe: if msg in _EMITTED_INFO: return - else: - _EMITTED_INFO.add(msg) + _EMITTED_INFO.add(msg) print(f"[cyan]Info: {msg}[/cyan]", **kwargs) @@ -144,8 +140,7 @@ def success(msg: str, dedupe: bool = False, **kwargs): if dedupe: if msg in _EMITTED_SUCCESS: return - else: - _EMITTED_SUCCESS.add(msg) + _EMITTED_SUCCESS.add(msg) print(f"[green]Success: {msg}[/green]", **kwargs) @@ -161,8 +156,7 @@ def log(msg: str, dedupe: bool = False, **kwargs): if dedupe: if msg in _EMITTED_LOGS: return - else: - _EMITTED_LOGS.add(msg) + _EMITTED_LOGS.add(msg) _console.log(msg, **kwargs) @@ -188,8 +182,7 @@ def warn(msg: str, dedupe: bool = False, **kwargs): if dedupe: if msg in _EMIITED_WARNINGS: return - else: - _EMIITED_WARNINGS.add(msg) + _EMIITED_WARNINGS.add(msg) print(f"[orange1]Warning: {msg}[/orange1]", **kwargs) @@ -271,8 +264,7 @@ def error(msg: str, dedupe: bool = False, **kwargs): if dedupe: if msg in _EMITTED_ERRORS: return - else: - _EMITTED_ERRORS.add(msg) + _EMITTED_ERRORS.add(msg) print(f"[red]{msg}[/red]", **kwargs) diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 266a090113a..67b9fbf1de7 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -244,14 +244,12 @@ def get_app_file() -> Path: sys.path.insert(0, current_working_dir) module_spec = importlib.util.find_spec(get_app_module()) if module_spec is None: - raise ImportError( - f"Module {get_app_module()} not found. Make sure the module is installed." - ) + msg = f"Module {get_app_module()} not found. Make sure the module is installed." + raise ImportError(msg) file_name = module_spec.origin if file_name is None: - raise ImportError( - f"Module {get_app_module()} not found. Make sure the module is installed." - ) + msg = f"Module {get_app_module()} not found. Make sure the module is installed." + raise ImportError(msg) return Path(file_name).resolve() diff --git a/reflex/utils/format.py b/reflex/utils/format.py index c5b04d5eb30..22c1347098f 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -73,7 +73,8 @@ def get_close_char(open: str, close: str | None = None) -> str: if close is not None: return close if open not in WRAP_MAP: - raise ValueError(f"Invalid wrap open: {open}, must be one of {WRAP_MAP.keys()}") + msg = f"Invalid wrap open: {open}, must be one of {WRAP_MAP.keys()}" + raise ValueError(msg) return WRAP_MAP[open] @@ -187,8 +188,7 @@ def to_camel_case(text: str, treat_hyphens_as_underscores: bool = True) -> str: # Capitalize the first letter of each word except the first one if len(words) == 1: return words[0] - converted_word = words[0] + "".join([w.capitalize() for w in words[1:]]) - return converted_word + return words[0] + "".join([w.capitalize() for w in words[1:]]) def to_title_case(text: str, sep: str = "") -> str: @@ -264,17 +264,13 @@ def escape_outside_segments(segment: str): if segment.startswith("${") and segment.endswith("}"): # Return the `${}` segment unchanged return segment - else: - # Escape backticks in the segment - segment = segment.replace(r"\`", "`") - segment = segment.replace("`", r"\`") - return segment + # Escape backticks in the segment + return segment.replace(r"\`", "`").replace("`", r"\`") # Split the string into parts, keeping the `${}` segments parts = re.split(r"(\$\{.*?\})", string) escaped_parts = [escape_outside_segments(part) for part in parts] - escaped_string = "".join(escaped_parts) - return escaped_string + return "".join(escaped_parts) def _wrap_js_string(string: str) -> str: @@ -287,8 +283,7 @@ def _wrap_js_string(string: str) -> str: The wrapped string. """ string = wrap(string, "`") - string = wrap(string, "{") - return string + return wrap(string, "{") def format_string(string: str) -> str: @@ -402,13 +397,13 @@ def format_prop( return str(Var.create(prop)) # Handle other types. - elif isinstance(prop, str): + if isinstance(prop, str): if is_wrapped(prop, "{"): return prop return json_dumps(prop) # For dictionaries, convert any properties to strings. - elif isinstance(prop, dict): + if isinstance(prop, dict): prop = serializers.serialize_dict(prop) # pyright: ignore [reportAttributeAccessIssue] else: @@ -417,11 +412,13 @@ def format_prop( except exceptions.InvalidStylePropError: raise except TypeError as e: - raise TypeError(f"Could not format prop: {prop} of type {type(prop)}") from e + msg = f"Could not format prop: {prop} of type {type(prop)}" + raise TypeError(msg) from e # Wrap the variable in braces. if not isinstance(prop, str): - raise ValueError(f"Invalid prop: {prop}. Expected a string.") + msg = f"Invalid prop: {prop}. Expected a string." + raise ValueError(msg) return wrap(prop, "{", check_first=False) @@ -591,9 +588,8 @@ def _default_args_spec(): elif isinstance(spec, type(lambda: None)): specs = call_event_fn(spec, args_spec or _default_args_spec) # pyright: ignore [reportAssignmentType, reportArgumentType] if isinstance(specs, Var): - raise ValueError( - f"Invalid event spec: {specs}. Expected a list of EventSpecs." - ) + msg = f"Invalid event spec: {specs}. Expected a list of EventSpecs." + raise ValueError(msg) payloads.extend(format_event(s) for s in specs) # Return the final code snippet, expecting queueEvents, processEvent, and socket to be in scope. @@ -662,12 +658,14 @@ def format_library_name(library_fullname: str | dict[str, Any]) -> str: # If input is a dictionary, extract the 'name' key if isinstance(library_fullname, dict): if "name" not in library_fullname: - raise KeyError("Dictionary input must contain a 'name' key") + msg = "Dictionary input must contain a 'name' key" + raise KeyError(msg) library_fullname = library_fullname["name"] # Process the library name as a string if not isinstance(library_fullname, str): - raise TypeError("Library name must be a string") + msg = "Library name must be a string" + raise TypeError(msg) if library_fullname.startswith("https://"): return library_fullname @@ -764,9 +762,8 @@ def format_data_editor_column(col: str | dict): if isinstance(col, Var): return col - raise ValueError( - f"unexpected type ({(type(col).__name__)}: {col}) for column header in data_editor" - ) + msg = f"unexpected type ({(type(col).__name__)}: {col}) for column header in data_editor" + raise ValueError(msg) def format_data_editor_cell(cell: Any): diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index c1cb0098fb4..e98bfb17bc3 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -127,10 +127,11 @@ def name(self) -> str: """ if self.alias: return ( - self.alias if self.is_default else " as ".join([self.tag, self.alias]) # pyright: ignore [reportCallIssue,reportArgumentType] + self.alias + if self.is_default + else (self.tag + " as " + self.alias if self.tag else self.alias) ) - else: - return self.tag or "" + return self.tag or "" ImportTypes = str | ImportVar | list[str | ImportVar] | list[ImportVar] diff --git a/reflex/utils/lazy_loader.py b/reflex/utils/lazy_loader.py index a4f887e7562..e2267566a39 100644 --- a/reflex/utils/lazy_loader.py +++ b/reflex/utils/lazy_loader.py @@ -65,7 +65,7 @@ def attach( def __getattr__(name: str): # noqa: N807 if name in submodules: return importlib.import_module(f"{package_name}.{name}") - elif name in attr_to_modules: + if name in attr_to_modules: submod_path = f"{package_name}.{attr_to_modules[name]}" submod = importlib.import_module(submod_path) attr = getattr(submod, name) @@ -78,8 +78,8 @@ def __getattr__(name: str): # noqa: N807 pkg.__dict__[name] = attr return attr - else: - raise AttributeError(f"No {package_name} attribute {name}") + msg = f"No {package_name} attribute {name}" + raise AttributeError(msg) def __dir__(): # noqa: N807 return __all__ diff --git a/reflex/utils/misc.py b/reflex/utils/misc.py index 1c5f948aa5f..ba8ada43a62 100644 --- a/reflex/utils/misc.py +++ b/reflex/utils/misc.py @@ -20,5 +20,6 @@ async def run_in_thread(func: Callable) -> Any: Any: The return value of the function. """ if asyncio.coroutines.iscoroutinefunction(func): - raise ValueError("func must be a non-async function") + msg = "func must be a non-async function" + raise ValueError(msg) return await asyncio.get_event_loop().run_in_executor(None, func) diff --git a/reflex/utils/net.py b/reflex/utils/net.py index 8202e483284..f90190867e6 100644 --- a/reflex/utils/net.py +++ b/reflex/utils/net.py @@ -19,7 +19,7 @@ def _httpx_verify_kwarg() -> bool: Returns: True if SSL verification is enabled, False otherwise """ - from ..config import environment + from reflex.config import environment return not environment.SSL_NO_VERIFY.get() @@ -131,7 +131,7 @@ def _httpx_local_address_kwarg() -> str: Returns: The local address to bind to """ - from ..config import environment + from reflex.config import environment return environment.REFLEX_HTTP_CLIENT_BIND_ADDRESS.get() or ( "::" if _should_use_ipv6() else "0.0.0.0" diff --git a/reflex/utils/path_ops.py b/reflex/utils/path_ops.py index 9588497aa38..c8cd8315fcc 100644 --- a/reflex/utils/path_ops.py +++ b/reflex/utils/path_ops.py @@ -300,7 +300,8 @@ def update_directory_tree(src: Path, dest: Path): ValueError: If the source is not a directory """ if not src.is_dir(): - raise ValueError(f"Source {src} is not a directory") + msg = f"Source {src} is not a directory" + raise ValueError(msg) # Ensure the destination directory exists dest.mkdir(parents=True, exist_ok=True) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 8277c2006f2..3f0a50d4931 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -129,7 +129,6 @@ def check_latest_package_version(package_name: str): ) except Exception: console.debug(f"Failed to check for the latest version of {package_name}.") - pass def get_or_set_last_reflex_version_check_datetime(): @@ -255,9 +254,8 @@ def get_nodejs_compatible_package_managers( package_managers = list(filter(None, package_managers)) if not package_managers and raise_on_none: - raise FileNotFoundError( - "Bun or npm not found. You might need to rerun `reflex init` or install either." - ) + msg = "Bun or npm not found. You might need to rerun `reflex init` or install either." + raise FileNotFoundError(msg) return package_managers @@ -310,9 +308,8 @@ def get_js_package_executor(raise_on_none: bool = False) -> Sequence[Sequence[st package_managers = list(filter(None, package_managers)) if not package_managers and raise_on_none: - raise FileNotFoundError( - "Bun or npm not found. You might need to rerun `reflex init` or install either." - ) + msg = "Bun or npm not found. You might need to rerun `reflex init` or install either." + raise FileNotFoundError(msg) return package_managers @@ -345,10 +342,11 @@ def _check_app_name(config: Config): RuntimeError: If the app name is not set in the config. """ if not config.app_name: - raise RuntimeError( + msg = ( "Cannot get the app module because `app_name` is not set in rxconfig! " "If this error occurs in a reflex test case, ensure that `get_app` is mocked." ) + raise RuntimeError(msg) def get_app(reload: bool = False) -> ModuleType: @@ -412,9 +410,8 @@ def get_and_validate_app(reload: bool = False) -> AppInfo: app_module = get_app(reload=reload) app = getattr(app_module, constants.CompileVars.APP) if not isinstance(app, App): - raise RuntimeError( - "The app instance in the specified app_module_import in rxconfig must be an instance of rx.App." - ) + msg = "The app instance in the specified app_module_import in rxconfig must be an instance of rx.App." + raise RuntimeError(msg) return AppInfo(app=app, module=app_module) @@ -575,9 +572,8 @@ def parse_redis_url() -> str | None: if not config.redis_url: return None if not config.redis_url.startswith(("redis://", "rediss://", "unix://")): - raise ValueError( - "REDIS_URL must start with 'redis://', 'rediss://', or 'unix://'." - ) + msg = "REDIS_URL must start with 'redis://', 'rediss://', or 'unix://'." + raise ValueError(msg) return config.redis_url @@ -1157,15 +1153,16 @@ def download_and_run(url: str, *args, show_status: bool = False, **env): raise click.exceptions.Exit(1) from None # Save the script to a temporary file. - script = Path(tempfile.NamedTemporaryFile().name) + with tempfile.NamedTemporaryFile() as tempfile_file: + script = Path(tempfile_file.name) - script.write_text(response.text) + script.write_text(response.text) - # Run the script. - env = {**os.environ, **env} - process = processes.new_process(["bash", str(script), *args], env=env) - show = processes.show_status if show_status else processes.show_logs - show(f"Installing {url}", process) + # Run the script. + env = {**os.environ, **env} + process = processes.new_process(["bash", str(script), *args], env=env) + show = processes.show_status if show_status else processes.show_logs + show(f"Installing {url}", process) def install_bun(): @@ -1213,7 +1210,8 @@ def install_bun(): ) else: if path_ops.which("unzip") is None: - raise SystemPackageMissingError("unzip") + msg = "unzip" + raise SystemPackageMissingError(msg) # Run the bun install script. download_and_run( @@ -1262,13 +1260,15 @@ def cached_procedure( ValueError: If both cache_file and cache_file_fn are provided. """ if cache_file and cache_file_fn is not None: - raise ValueError("cache_file and cache_file_fn cannot both be provided.") + msg = "cache_file and cache_file_fn cannot both be provided." + raise ValueError(msg) def _inner_decorator(func: Callable): def _inner(*args, **kwargs): _cache_file = cache_file_fn() if cache_file_fn is not None else cache_file if not _cache_file: - raise ValueError("Unknown cache file, cannot cache result.") + msg = "Unknown cache file, cannot cache result." + raise ValueError(msg) payload = _read_cached_procedure_file(_cache_file) new_payload = payload_fn(*args, **kwargs) if payload != new_payload: @@ -1446,7 +1446,7 @@ def validate_bun(bun_path: Path | None = None): "Failed to obtain bun version. Make sure the specified bun path in your config is correct." ) raise click.exceptions.Exit(1) - elif bun_version < version.parse(constants.Bun.MIN_VERSION): + if bun_version < version.parse(constants.Bun.MIN_VERSION): console.warn( f"Reflex requires bun version {constants.Bun.MIN_VERSION} or higher to run, but the detected version is " f"{bun_version}. If you have specified a custom bun path in your config, make sure to provide one " @@ -1657,8 +1657,7 @@ def get_release_by_tag(tag: str) -> dict | None: if asset is None: console.warn(f"Templates metadata not found for version {version}") return {} - else: - templates_url = asset["browser_download_url"] + templates_url = asset["browser_download_url"] templates_data = net.get(templates_url, follow_redirects=True).json()["templates"] @@ -1864,7 +1863,7 @@ def initialize_app(app_name: str, template: str | None = None) -> str | None: # Check if the app is already initialized. if constants.Config.FILE.exists(): telemetry.send("reinit") - return + return None templates: dict[str, Template] = {} diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index c8eadd69f23..24bc6099846 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -137,11 +137,10 @@ def handle_port(service_name: str, port: int, auto_increment: bool) -> int: return port if auto_increment: return change_port(port, service_name) - else: - console.error( - f"{service_name.capitalize()} port: {port} is already in use by PID: {process.pid}." - ) - raise click.exceptions.Exit() + console.error( + f"{service_name.capitalize()} port: {port} is already in use by PID: {process.pid}." + ) + raise click.exceptions.Exit @overload @@ -316,7 +315,6 @@ def stream_logs( # But if the process is still running that is weird. raise # If the process exited, break out of the loop for post processing. - pass # Check if the process failed (not printing the logs for SIGINT). diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index b5c1ff78ecd..5598af8f16f 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -159,9 +159,8 @@ def _get_type_hint( res_args.sort() if len(res_args) == 1: return f"{res_args[0]} | None" - else: - res = f"{' | '.join(res_args)}" - return f"{res} | None" + res = f"{' | '.join(res_args)}" + return f"{res} | None" res_args = [ _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg)) @@ -185,10 +184,11 @@ def _get_type_hint( value.__module__ not in ["builtins", "__builtins__"] and value.__name__ not in type_hint_globals ): - raise TypeError( + msg = ( f"{value.__module__ + '.' + value.__name__} is not a default import, " "add it to DEFAULT_IMPORTS in pyi_generator.py" ) + raise TypeError(msg) res = f"{value.__name__}[{', '.join(inner_container_type_args)}]" @@ -447,7 +447,7 @@ def type_to_ast(typ: Any, cls: type) -> ast.expr: return ast.Name(id=typ.__module__ + "." + typ.__name__) return ast.Name(id=typ.__name__) - elif hasattr(typ, "_name"): + if hasattr(typ, "_name"): return ast.Name(id=typ._name) return ast.Name(id=str(typ)) @@ -512,7 +512,8 @@ def _generate_component_create_functiondef( TypeError: If clz is not a subclass of Component. """ if not issubclass(clz, Component): - raise TypeError(f"clz must be a subclass of Component, not {clz!r}") + msg = f"clz must be a subclass of Component, not {clz!r}" + raise TypeError(msg) # add the imports needed by get_type_hint later type_hint_globals.update( @@ -656,7 +657,7 @@ def figure_out_return_type(annotation: Any): defaults=[], ) - definition = ast.FunctionDef( # pyright: ignore [reportCallIssue] + return ast.FunctionDef( # pyright: ignore [reportCallIssue] name="create", args=create_args, body=[ @@ -678,7 +679,6 @@ def figure_out_return_type(annotation: Any): lineno=lineno, returns=ast.Constant(value=clz.__name__), ) - return definition def _generate_staticmethod_call_functiondef( @@ -712,7 +712,7 @@ def _generate_staticmethod_call_functiondef( else [] ), ) - definition = ast.FunctionDef( # pyright: ignore [reportCallIssue] + return ast.FunctionDef( # pyright: ignore [reportCallIssue] name="__call__", args=call_args, body=[ @@ -731,7 +731,6 @@ def _generate_staticmethod_call_functiondef( ) ), ) - return definition def _generate_namespace_call_functiondef( @@ -843,6 +842,7 @@ def _current_class_is_component(self) -> type[Component] | None: and issubclass((clz := self.classes[self.current_class]), Component) ): return clz + return None def visit_Module(self, node: ast.Module) -> ast.Module: """Visit a Module node and remove docstring from body. @@ -1023,7 +1023,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign | None: if isinstance(target, ast.Tuple): for name in target.elts: if isinstance(name, ast.Name) and name.id.startswith("_"): - return + return None return node @@ -1109,7 +1109,7 @@ def _get_init_lazy_imports(self, mod: tuple | ModuleType, new_tree: ast.AST): pyright_ignore_imports = getattr(mod, "_PYRIGHT_IGNORE_IMPORTS", []) if not sub_mods and not sub_mod_attrs: - return + return None sub_mods_imports = [] sub_mod_attrs_imports = [] @@ -1164,7 +1164,7 @@ def _scan_file(self, module_path: Path) -> tuple[str, str] | None: } is_init_file = _relative_to_pwd(module_path).name == "__init__.py" if not class_names and not is_init_file: - return + return None if is_init_file: new_tree = InitStubGenerator(module, class_names).visit( @@ -1172,7 +1172,7 @@ def _scan_file(self, module_path: Path) -> tuple[str, str] | None: ) init_imports = self._get_init_lazy_imports(module, new_tree) if not init_imports: - return + return None content_hash = self._write_pyi_file(module_path, init_imports) else: new_tree = StubGenerator(module, class_names).visit( diff --git a/reflex/utils/redir.py b/reflex/utils/redir.py index 243e9153eed..0decef740b9 100644 --- a/reflex/utils/redir.py +++ b/reflex/utils/redir.py @@ -5,9 +5,9 @@ import httpx +from reflex import constants from reflex.utils import net -from .. import constants from . import console diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 8682dde3663..6304018629a 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -76,7 +76,8 @@ def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION: # Make sure the function takes a single argument. if len(args) != 1: - raise ValueError("Serializer must take a single argument.") + msg = "Serializer must take a single argument." + raise ValueError(msg) # Get the type of the argument. type_ = type_hints[args[0]] @@ -166,8 +167,7 @@ def serialize( # Return the serialized value and the type. if get_type: return serialized, get_serializer_type(type(value)) - else: - return serialized + return serialized @functools.lru_cache @@ -427,7 +427,7 @@ def format_dataframe_values(df: DataFrame) -> list[list[Any]]: """ return [ [str(d) if isinstance(d, (list, tuple)) else d for d in data] - for data in list(df.values.tolist()) + for data in list(df.to_numpy().tolist()) ] @serializer diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py index d5e36bd01c3..7c2817b9fad 100644 --- a/reflex/utils/telemetry.py +++ b/reflex/utils/telemetry.py @@ -232,6 +232,9 @@ def _send(event: str, telemetry_enabled: bool | None, **kwargs) -> bool: return False +background_tasks = set() + + def send(event: str, telemetry_enabled: bool | None = None, **kwargs): """Send anonymous telemetry for Reflex. @@ -246,7 +249,9 @@ async def async_send(event: str, telemetry_enabled: bool | None, **kwargs): try: # Within an event loop context, send the event asynchronously. - asyncio.create_task(async_send(event, telemetry_enabled, **kwargs)) + task = asyncio.create_task(async_send(event, telemetry_enabled, **kwargs)) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) except RuntimeError: # If there is no event loop, send the event synchronously. warnings.filterwarnings("ignore", category=RuntimeWarning) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index d9b79b17ac8..028c04ef46e 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -382,7 +382,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None if hasattr(cls, "__fields__") and name in cls.__fields__: # pydantic models return get_field_type(cls, name) - elif isinstance(cls, type) and issubclass(cls, DeclarativeBase): + if isinstance(cls, type) and issubclass(cls, DeclarativeBase): insp = sqlalchemy.inspect(cls) if name in insp.columns: # check for list types @@ -414,8 +414,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None if isinstance(prop, Relationship): type_ = prop.mapper.class_ # TODO: check for nullable? - type_ = list[type_] if prop.uselist else type_ | None - return type_ + return list[type_] if prop.uselist else type_ | None if isinstance(attr, AssociationProxyInstance): return list[ get_attribute_access_type( @@ -448,7 +447,6 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None return hints[name] except exceptions as e: console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}") - pass return None # Attribute is not accessible. @@ -469,7 +467,8 @@ def get_base_class(cls: GenericType) -> type: # only literals of the same type are supported. arg_type = type(get_args(cls)[0]) if not all(type(arg) is arg_type for arg in get_args(cls)): - raise TypeError("only literals of the same type are supported") + msg = "only literals of the same type are supported" + raise TypeError(msg) return type(get_args(cls)[0]) if is_union(cls): @@ -497,13 +496,13 @@ def _breakpoints_satisfies_typing(cls_check: GenericType, instance: Any) -> bool if not isinstance(value, str) or value not in get_args(expected_type): return False return True - elif isinstance(cls_check_base, tuple): + if isinstance(cls_check_base, tuple): # union type, so check all types return any( _breakpoints_satisfies_typing(type_to_check, instance) for type_to_check in get_args(cls_check) ) - elif cls_check_base == reflex.vars.Var and "__args__" in cls_check.__dict__: + if cls_check_base == reflex.vars.Var and "__args__" in cls_check.__dict__: return _breakpoints_satisfies_typing(get_args(cls_check)[0], instance) return False @@ -555,7 +554,8 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None) except TypeError as te: # These errors typically arise from bad annotations and are hard to # debug without knowing the type that we tried to compare. - raise TypeError(f"Invalid type for issubclass: {cls_base}") from te + msg = f"Invalid type for issubclass: {cls_base}" + raise TypeError(msg) from te def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool: @@ -913,9 +913,8 @@ def validate_literal(key: str, value: Any, expected_type: type, comp_name: str): [str(v) if not isinstance(v, str) else f"'{v}'" for v in allowed_values] ) value_str = f"'{value}'" if isinstance(value, str) else value - raise ValueError( - f"prop value for {key!s} of the `{comp_name}` component should be one of the following: {allowed_value_str}. Got {value_str} instead" - ) + msg = f"prop value for {key!s} of the `{comp_name}` component should be one of the following: {allowed_value_str}. Got {value_str} instead" + raise ValueError(msg) def validate_parameter_literals(func: Callable): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 60899ec4bba..d3e106f1a69 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -247,9 +247,8 @@ def merge(*all: VarData | None) -> VarData | None: ) if positions: if len(positions) > 1: - raise exceptions.ReflexError( - f"Cannot merge var data with different positions: {positions}" - ) + msg = f"Cannot merge var data with different positions: {positions}" + raise exceptions.ReflexError(msg) position = positions[0] else: position = None @@ -511,14 +510,12 @@ def __post_init__(self): TypeError: If _js_expr is not a string. """ if not isinstance(self._js_expr, str): - raise TypeError( - f"Expected _js_expr to be a string, got value {self._js_expr!r} of type {type(self._js_expr).__name__}" - ) + msg = f"Expected _js_expr to be a string, got value {self._js_expr!r} of type {type(self._js_expr).__name__}" + raise TypeError(msg) if self._var_data is not None and not isinstance(self._var_data, VarData): - raise TypeError( - f"Expected _var_data to be a VarData, got value {self._var_data!r} of type {type(self._var_data).__name__}" - ) + msg = f"Expected _var_data to be a VarData, got value {self._var_data!r} of type {type(self._var_data).__name__}" + raise TypeError(msg) # Decode any inline Var markup and apply it to the instance _var_data, _js_expr = _decode_var_immutable(self._js_expr) @@ -597,15 +594,16 @@ def _replace( TypeError: If _var_is_local, _var_is_string, or _var_full_name_needs_state_prefix is not None. """ if kwargs.get("_var_is_local", False) is not False: - raise TypeError("The _var_is_local argument is not supported for Var.") + msg = "The _var_is_local argument is not supported for Var." + raise TypeError(msg) if kwargs.get("_var_is_string", False) is not False: - raise TypeError("The _var_is_string argument is not supported for Var.") + msg = "The _var_is_string argument is not supported for Var." + raise TypeError(msg) if kwargs.get("_var_full_name_needs_state_prefix", False) is not False: - raise TypeError( - "The _var_full_name_needs_state_prefix argument is not supported for Var." - ) + msg = "The _var_full_name_needs_state_prefix argument is not supported for Var." + raise TypeError(msg) value_with_replaced = dataclasses.replace( self, _var_type=_var_type or self._var_type, @@ -850,10 +848,9 @@ def to( new_var_type = var_type else: new_var_type = var_type or current_var_type - to_operation_return = var_subclass.to_var_subclass.create( + return var_subclass.to_var_subclass.create( # pyright: ignore [reportReturnType] value=self, _var_type=new_var_type ) - return to_operation_return # pyright: ignore [reportReturnType] # If we can't determine the first argument, we just replace the _var_type. if not safe_issubclass(output, Var) or var_type is None: @@ -940,7 +937,8 @@ def guess_type(self) -> Var: fixed_type = unionize(*(type(arg) for arg in args)) if not inspect.isclass(fixed_type): - raise TypeError(f"Unsupported type {var_type} for guess_type.") + msg = f"Unsupported type {var_type} for guess_type." + raise TypeError(msg) if fixed_type is None: return self.to(None) @@ -992,9 +990,8 @@ def _get_default_value(self) -> Any: return pd.DataFrame() except ImportError as e: - raise ImportError( - "Please install pandas to use dataframes in your app." - ) from e + msg = "Please install pandas to use dataframes in your app." + raise ImportError(msg) from e return set() if safe_issubclass(type_, set) else None def _get_setter_name(self, include_state: bool = True) -> str: @@ -1012,7 +1009,7 @@ def _get_setter_name(self, include_state: bool = True) -> str: return setter if not include_state or var_data.state == "": return setter - return ".".join((var_data.state, setter)) + return var_data.state + "." + setter def _get_setter(self) -> Callable[[BaseState, Any], None]: """Get the var's setter function. @@ -1348,9 +1345,8 @@ def __getitem__(self, key: Any) -> Var: self, f"access the item '{key}'", ) - raise TypeError( - f"Var of type {self._var_type} does not support item access." - ) + msg = f"Var of type {self._var_type} does not support item access." + raise TypeError(msg) def __getattr__(self, name: str): """Get an attribute of the var. @@ -1366,14 +1362,15 @@ def __getattr__(self, name: str): # noqa: DAR101 self """ if name.startswith("_"): - raise VarAttributeError(f"Attribute {name} not found.") + msg = f"Attribute {name} not found." + raise VarAttributeError(msg) if name == "contains": - raise TypeError( - f"Var of type {self._var_type} does not support contains check." - ) + msg = f"Var of type {self._var_type} does not support contains check." + raise TypeError(msg) if name == "reverse": - raise TypeError("Cannot reverse non-list var.") + msg = "Cannot reverse non-list var." + raise TypeError(msg) if self._var_type is Any: raise exceptions.UntypedVarError( @@ -1381,9 +1378,8 @@ def __getattr__(self, name: str): f"access the attribute '{name}'", ) - raise VarAttributeError( - f"The State var {escape(self._js_expr)} of type {escape(str(self._var_type))} has no attribute '{name}' or may have been annotated wrongly.", - ) + msg = f"The State var {escape(self._js_expr)} of type {escape(str(self._var_type))} has no attribute '{name}' or may have been annotated wrongly." + raise VarAttributeError(msg) def __bool__(self) -> bool: """Raise exception if using Var in a boolean context. @@ -1393,10 +1389,11 @@ def __bool__(self) -> bool: # noqa: DAR101 self """ - raise VarTypeError( + msg = ( f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. " "Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)." ) + raise VarTypeError(msg) def __iter__(self) -> Any: """Raise exception if using Var in an iterable context. @@ -1406,9 +1403,8 @@ def __iter__(self) -> Any: # noqa: DAR101 self """ - raise VarTypeError( - f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`." - ) + msg = f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`." + raise VarTypeError(msg) def __contains__(self, _: Any) -> Var: """Override the 'in' operator to alert the user that it is not supported. @@ -1418,9 +1414,10 @@ def __contains__(self, _: Any) -> Var: # noqa: DAR101 self """ - raise VarTypeError( + msg = ( "'in' operator not supported for Var types, use Var.contains() instead." ) + raise VarTypeError(msg) OUTPUT = TypeVar("OUTPUT", bound=Var) @@ -1522,9 +1519,8 @@ def __init_subclass__(cls, **kwargs): ] if not possible_bases: - raise TypeError( - f"LiteralVar subclass {cls} must have a base class that is a subclass of Var and not LiteralVar." - ) + msg = f"LiteralVar subclass {cls} must have a base class that is a subclass of Var and not LiteralVar." + raise TypeError(msg) var_subclasses = [ var_subclass @@ -1533,14 +1529,12 @@ def __init_subclass__(cls, **kwargs): ] if not var_subclasses: - raise TypeError( - f"LiteralVar {cls} must have a base class annotated with `python_types`." - ) + msg = f"LiteralVar {cls} must have a base class annotated with `python_types`." + raise TypeError(msg) if len(var_subclasses) != 1: - raise TypeError( - f"LiteralVar {cls} must have exactly one base class annotated with `python_types`." - ) + msg = f"LiteralVar {cls} must have exactly one base class annotated with `python_types`." + raise TypeError(msg) var_subclass = var_subclasses[0] @@ -1630,9 +1624,8 @@ def _create_literal_var( if isinstance(value, range): return ArrayVar.range(value.start, value.stop, value.step) - raise TypeError( - f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}." - ) + msg = f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}." + raise TypeError(msg) if not TYPE_CHECKING: create = _create_literal_var @@ -1642,9 +1635,8 @@ def __post_init__(self): @property def _var_value(self) -> Any: - raise NotImplementedError( - "LiteralVar subclasses must implement the _var_value property." - ) + msg = "LiteralVar subclasses must implement the _var_value property." + raise NotImplementedError(msg) def json(self) -> str: """Serialize the var to a JSON string. @@ -1652,9 +1644,8 @@ def json(self) -> str: Raises: NotImplementedError: If the method is not implemented. """ - raise NotImplementedError( - "LiteralVar subclasses must implement the json method." - ) + msg = "LiteralVar subclasses must implement the json method." + raise NotImplementedError(msg) @serializers.serializer @@ -1881,10 +1872,11 @@ def delete_property(this: Any): owner.__del__ = delete_property elif name != self._attrname: - raise TypeError( + msg = ( "Cannot assign the same cached_property to two different names " f"({self._attrname!r} and {name!r})." ) + raise TypeError(msg) def __get__(self, instance: Any, owner: type | None = None): """Get the cached property. @@ -1900,9 +1892,8 @@ def __get__(self, instance: Any, owner: type | None = None): TypeError: If the class does not have __set_name__. """ if self._attrname is None: - raise TypeError( - "Cannot use cached_property on a class without __set_name__." - ) + msg = "Cannot use cached_property on a class without __set_name__." + raise TypeError(msg) cached_field_name = "_reflex_cache_" + self._attrname try: unique_id = object.__getattribute__(instance, cached_field_name) @@ -2164,7 +2155,8 @@ def __init__( ) if kwargs: - raise TypeError(f"Unexpected keyword arguments: {tuple(kwargs)}") + msg = f"Unexpected keyword arguments: {tuple(kwargs)}" + raise TypeError(msg) if backend is None: backend = fget.__name__.startswith("_") @@ -2239,9 +2231,8 @@ def _add_static_dep( elif isinstance(dep, str) and dep != "": deps.setdefault(None, set()).add(dep) else: - raise TypeError( - "ComputedVar dependencies must be Var instances or var names (non-empty strings)." - ) + msg = "ComputedVar dependencies must be Var instances or var names (non-empty strings)." + raise TypeError(msg) return deps @override @@ -2282,7 +2273,8 @@ def _replace( if kwargs: unexpected_kwargs = ", ".join(kwargs.keys()) - raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}") + msg = f"Unexpected keyword arguments: {unexpected_kwargs}" + raise TypeError(msg) return type(self)(**field_values) @@ -2529,10 +2521,11 @@ def add_dependency(self, objclass: type[BaseState], dep: Var): (objclass.get_full_name(), self._js_expr) ) return - raise VarDependencyError( + msg = ( "ComputedVar dependencies must be Var instances with a state and " f"field name, got {dep!r}." ) + raise VarDependencyError(msg) def _determine_var_type(self) -> type: """Get the type of the var. @@ -2567,8 +2560,6 @@ def fget(self) -> Callable[[BaseState], RETURN_TYPE]: class DynamicRouteVar(ComputedVar[str | list[str]]): """A ComputedVar that represents a dynamic route.""" - pass - async def _default_async_computed_var(_self: BaseState) -> Any: return None @@ -2683,23 +2674,21 @@ async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE: return value return _awaitable_result() - else: - # handle caching - async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE: - if not hasattr(instance, self._cache_attr) or self.needs_update( - instance - ): - # Set cache attr on state instance. - setattr(instance, self._cache_attr, await self.fget(instance)) - # Ensure the computed var gets serialized to redis. - instance._was_touched = True - # Set the last updated timestamp on the state instance. - setattr(instance, self._last_updated_attr, datetime.datetime.now()) - value = getattr(instance, self._cache_attr) - self._check_deprecated_return_type(instance, value) - return value - return _awaitable_result() + # handle caching + async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE: + if not hasattr(instance, self._cache_attr) or self.needs_update(instance): + # Set cache attr on state instance. + setattr(instance, self._cache_attr, await self.fget(instance)) + # Ensure the computed var gets serialized to redis. + instance._was_touched = True + # Set the last updated timestamp on the state instance. + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + value = getattr(instance, self._cache_attr) + self._check_deprecated_return_type(instance, value) + return value + + return _awaitable_result() @property def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]: @@ -2806,10 +2795,12 @@ def computed_var( ComputedVarSignatureError: If the getter function has more than one argument. """ if cache is False and interval is not None: - raise ValueError("Cannot set update interval without caching.") + msg = "Cannot set update interval without caching." + raise ValueError(msg) if cache is False and (deps is not None or auto_deps is False): - raise VarDependencyError("Cannot track dependencies without caching.") + msg = "Cannot track dependencies without caching." + raise VarDependencyError(msg) if fget is not None: sign = inspect.signature(fget) @@ -3031,7 +3022,8 @@ def get_to_operation(var_subclass: type[Var]) -> type[ToOperation]: if saved_var_subclass.var_subclass is var_subclass ] if not possible_classes: - raise ValueError(f"Could not find ToOperation for {var_subclass}.") + msg = f"Could not find ToOperation for {var_subclass}." + raise ValueError(msg) return possible_classes[0] @@ -3172,21 +3164,20 @@ def transform(fn: Callable[[Var], Var]) -> Callable[[Var], Var]: origin = get_origin(return_type) if origin is not Var: - raise TypeError( - f"Expected return type of {fn.__name__} to be a Var, got {origin}." - ) + msg = f"Expected return type of {fn.__name__} to be a Var, got {origin}." + raise TypeError(msg) generic_args = get_args(return_type) if not generic_args: - raise TypeError( - f"Expected Var return type of {fn.__name__} to have a generic type." - ) + msg = f"Expected Var return type of {fn.__name__} to have a generic type." + raise TypeError(msg) generic_type = get_origin(generic_args[0]) or generic_args[0] if generic_type in dispatchers: - raise ValueError(f"Function for {generic_type} already registered.") + msg = f"Function for {generic_type} already registered." + raise ValueError(msg) dispatchers[generic_type] = fn @@ -3215,17 +3206,15 @@ def generic_type_to_actual_type_map( if generic_origin is not actual_origin: if isinstance(generic_origin, TypeVar): return {generic_origin: actual_origin} - raise TypeError( - f"Type mismatch: expected {generic_origin}, got {actual_origin}." - ) + msg = f"Type mismatch: expected {generic_origin}, got {actual_origin}." + raise TypeError(msg) generic_args = get_args(generic_type) actual_args = get_args(actual_type) if len(generic_args) != len(actual_args): - raise TypeError( - f"Number of generic arguments mismatch: expected {len(generic_args)}, got {len(actual_args)}." - ) + msg = f"Number of generic arguments mismatch: expected {len(generic_args)}, got {len(actual_args)}." + raise TypeError(msg) # call recursively for nested generic types and merge the results return { @@ -3326,28 +3315,26 @@ def dispatch( fn_return_origin = get_origin(fn_return) or fn_return if fn_return_origin is not Var: - raise TypeError( - f"Expected return type of {fn.__name__} to be a Var, got {fn_return}." - ) + msg = f"Expected return type of {fn.__name__} to be a Var, got {fn_return}." + raise TypeError(msg) fn_return_generic_args = get_args(fn_return) if not fn_return_generic_args: - raise TypeError(f"Expected generic type of {fn_return} to be a type.") + msg = f"Expected generic type of {fn_return} to be a type." + raise TypeError(msg) arg_origin = get_origin(fn_first_arg_type) or fn_first_arg_type if arg_origin is not Var: - raise TypeError( - f"Expected first argument of {fn.__name__} to be a Var, got {fn_first_arg_type}." - ) + msg = f"Expected first argument of {fn.__name__} to be a Var, got {fn_first_arg_type}." + raise TypeError(msg) arg_generic_args = get_args(fn_first_arg_type) if not arg_generic_args: - raise TypeError( - f"Expected generic type of {fn_first_arg_type} to be a type." - ) + msg = f"Expected generic type of {fn_first_arg_type} to be a type." + raise TypeError(msg) arg_type = arg_generic_args[0] fn_return_type = fn_return_generic_args[0] diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py index 04a7b70dfb0..f71172e9c51 100644 --- a/reflex/vars/datetime.py +++ b/reflex/vars/datetime.py @@ -29,7 +29,8 @@ def raise_var_type_error(): Raises: VarTypeError: Cannot compare a datetime object with a non-datetime object. """ - raise VarTypeError("Cannot compare a datetime object with a non-datetime object.") + msg = "Cannot compare a datetime object with a non-datetime object." + raise VarTypeError(msg) class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)): diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py index 9fddc67e46e..77a26a2e7e4 100644 --- a/reflex/vars/dep_tracking.py +++ b/reflex/vars/dep_tracking.py @@ -106,9 +106,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None: from .base import ComputedVar if instruction.argval in self.INVALID_NAMES: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." - ) + msg = f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." + raise VarValueError(msg) if instruction.argval == "get_state": # Special case: arbitrary state access requested. self.scan_status = ScanStatus.GETTING_STATE @@ -193,37 +192,32 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None: from reflex.state import BaseState if instruction.opname == "LOAD_FAST": - raise VarValueError( - f"Dependency detection cannot identify get_state class from local var {instruction.argval}." - ) + msg = f"Dependency detection cannot identify get_state class from local var {instruction.argval}." + raise VarValueError(msg) if isinstance(self.func, CodeType): - raise VarValueError( - "Dependency detection cannot identify get_state class from a code object." - ) + msg = "Dependency detection cannot identify get_state class from a code object." + raise VarValueError(msg) if instruction.opname == "LOAD_GLOBAL": # Special case: referencing state class from global scope. try: self._getting_state_class = self._get_globals()[instruction.argval] except (ValueError, KeyError) as ve: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals." - ) from ve + msg = f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals." + raise VarValueError(msg) from ve elif instruction.opname == "LOAD_DEREF": # Special case: referencing state class from closure. try: self._getting_state_class = self._get_closure()[instruction.argval] except (ValueError, KeyError) as ve: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?" - ) from ve + msg = f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?" + raise VarValueError(msg) from ve elif instruction.opname == "STORE_FAST": # Storing the result of get_state in a local variable. if not isinstance(self._getting_state_class, type) or not issubclass( self._getting_state_class, BaseState ): - raise VarValueError( - f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." - ) + msg = f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." + raise VarValueError(msg) self.tracked_locals[instruction.argval] = self._getting_state_class self.scan_status = ScanStatus.SCANNING self._getting_state_class = None @@ -242,9 +236,8 @@ def _eval_var(self) -> Var: positions0 = self._getting_var_instructions[0].positions positions1 = self._getting_var_instructions[-1].positions if module is None or positions0 is None or positions1 is None: - raise VarValueError( - f"Cannot determine the source code for the var in {self.func!r}." - ) + msg = f"Cannot determine the source code for the var in {self.func!r}." + raise VarValueError(msg) start_line = positions0.lineno start_column = positions0.col_offset end_line = positions1.end_lineno @@ -255,9 +248,8 @@ def _eval_var(self) -> Var: or end_line is None or end_column is None ): - raise VarValueError( - f"Cannot determine the source code for the var in {self.func!r}." - ) + msg = f"Cannot determine the source code for the var in {self.func!r}." + raise VarValueError(msg) source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line] # Create a python source string snippet. if len(source) > 1: @@ -292,9 +284,8 @@ def handle_getting_var(self, instruction: dis.Instruction) -> None: the_var = self._eval_var() the_var_data = the_var._get_all_var_data() if the_var_data is None: - raise VarValueError( - f"Cannot determine the source code for the var in {self.func!r}." - ) + msg = f"Cannot determine the source code for the var in {self.func!r}." + raise VarValueError(msg) self.dependencies.setdefault(the_var_data.state, set()).add( the_var_data.field_name ) diff --git a/reflex/vars/number.py b/reflex/vars/number.py index 350fcbdf5a1..62323faa653 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -53,9 +53,8 @@ def raise_unsupported_operand_types( Raises: VarTypeError: The operand types are unsupported. """ - raise VarTypeError( - f"Unsupported Operand type(s) for {operator}: {', '.join(t.__name__ for t in operands_types)}" - ) + msg = f"Unsupported Operand type(s) for {operator}: {', '.join(t.__name__ for t in operands_types)}" + raise VarTypeError(msg) class NumberVar(Var[NUMBER_T], python_types=(int, float, decimal.Decimal)): @@ -486,10 +485,11 @@ def __format__(self, format_spec: str) -> str: ) if format_spec: - raise VarValueError( + msg = ( f"Unknown format code '{format_spec}' for object of type 'NumberVar'. It is only supported to use ',', '_', and '.f' for float numbers." "If possible, use computed variables instead: https://reflex.dev/docs/vars/computed-vars/" ) + raise VarValueError(msg) return super().__format__(format_spec) @@ -961,9 +961,8 @@ def json(self) -> str: if isinstance(self._var_value, decimal.Decimal): return json.dumps(float(self._var_value)) if math.isinf(self._var_value) or math.isnan(self._var_value): - raise PrimitiveUnserializableToJSONError( - f"No valid JSON representation for {self}" - ) + msg = f"No valid JSON representation for {self}" + raise PrimitiveUnserializableToJSONError(msg) return json.dumps(self._var_value) def __hash__(self) -> int: diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 6b5fb754a9e..3d514926624 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -333,13 +333,13 @@ def __getattr__(self, name: str) -> Var: ): attribute_type = get_attribute_access_type(var_type, name) if attribute_type is None: - raise VarAttributeError( + msg = ( f"The State var `{self!s}` of type {escape(str(self._var_type))} has no attribute '{name}' or may have been annotated " f"wrongly." ) + raise VarAttributeError(msg) return ObjectItemOperation.create(self, name, attribute_type).guess_type() - else: - return ObjectItemOperation.create(self, name).guess_type() + return ObjectItemOperation.create(self, name).guess_type() def contains(self, key: Var | Any) -> BooleanVar: """Check if the object contains a key. @@ -413,9 +413,8 @@ def json(self) -> str: key = LiteralVar.create(key) value = LiteralVar.create(value) if not isinstance(key, LiteralVar) or not isinstance(value, LiteralVar): - raise TypeError( - "The keys and values of the object must be literal vars to get the JSON representation." - ) + msg = "The keys and values of the object must be literal vars to get the JSON representation." + raise TypeError(msg) keys_and_values.append(f"{key.json()}:{value.json()}") return "{" + ", ".join(keys_and_values) + "}" diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index da9fdce513b..210faa5bf62 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -445,9 +445,8 @@ def foreach(self, fn: Any): # get the number of arguments of the function num_args = len(inspect.signature(fn).parameters) if num_args > 1: - raise VarTypeError( - "The function passed to foreach should take at most one argument." - ) + msg = "The function passed to foreach should take at most one argument." + raise VarTypeError(msg) if num_args == 0: return_value = fn() @@ -535,9 +534,8 @@ def json(self) -> str: for element in self._var_value: element_var = LiteralVar.create(element) if not isinstance(element_var, LiteralVar): - raise TypeError( - f"Array elements must be of type LiteralVar, not {type(element_var)}" - ) + msg = f"Array elements must be of type LiteralVar, not {type(element_var)}" + raise TypeError(msg) elements.append(element_var.json()) return "[" + ", ".join(elements) + "]" @@ -1203,8 +1201,7 @@ def create( only_string = filtered_strings_and_vals[0] if isinstance(only_string, str): return LiteralVar.create(only_string).to(StringVar, _var_type) - else: - return only_string.to(StringVar, only_string._var_type) + return only_string.to(StringVar, only_string._var_type) if len( literal_strings := [ @@ -1400,7 +1397,8 @@ def _cached_var_name(self) -> str: actual_end = start + 1 if start is not None else self._array.length() return str(self._array[actual_start:actual_end].reverse()[::-step]) if step == 0: - raise ValueError("slice step cannot be zero") + msg = "slice step cannot be zero" + raise ValueError(msg) return f"{self._array!s}.slice({normalized_start!s}, {normalized_end!s}).filter((_, i) => i % {step!s} === 0)" actual_start_reverse = end + 1 if end is not None else 0 @@ -1866,13 +1864,15 @@ def json(self) -> str: (self._var_value.color, self._var_value.alpha, self._var_value.shade), ) if color is None or alpha is None or shade is None: - raise TypeError("Cannot serialize color that contains non-literal vars.") + msg = "Cannot serialize color that contains non-literal vars." + raise TypeError(msg) if ( not isinstance(color, str) or not isinstance(alpha, bool) or not isinstance(shade, int) ): - raise TypeError("Color is not a valid color.") + msg = "Color is not a valid color." + raise TypeError(msg) return f"var(--{color}-{'a' if alpha else ''}{shade})" diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 00000000000..84f0d7f7aff --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1 @@ +"""Utility scripts for the project.""" diff --git a/tests/integration/shared/state.py b/tests/integration/shared/state.py index be6aa1f2aed..60b6b5c9ac7 100644 --- a/tests/integration/shared/state.py +++ b/tests/integration/shared/state.py @@ -5,5 +5,3 @@ class SharedState(rx.State): """Shared state class for reflexers using librarys.""" - - pass diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index bdc0d2e762c..9c1a83f60a9 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -229,7 +229,7 @@ def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]: driver.quit() -@pytest.fixture() +@pytest.fixture def token(background_task: AppHarness, driver: WebDriver) -> str: """Get a function that returns the active token. diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 0d45f74f376..a4ef1938dc9 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -2,7 +2,7 @@ from __future__ import annotations -import time +import asyncio from collections.abc import Generator import pytest @@ -175,7 +175,7 @@ def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]: driver.quit() -@pytest.fixture() +@pytest.fixture def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None]: """Get an instance of the local storage helper. @@ -190,7 +190,7 @@ def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None ls.clear() -@pytest.fixture() +@pytest.fixture def session_storage(driver: WebDriver) -> Generator[utils.SessionStorage, None, None]: """Get an instance of the session storage helper. @@ -438,7 +438,7 @@ def set_sub_sub(var: str, value: str): "secure": False, "value": "c3%20value", } - time.sleep(2) # wait for c3 to expire + await asyncio.sleep(2) # wait for c3 to expire if not isinstance(driver, Firefox): # Note: Firefox does not remove expired cookies Bug 576347 assert f"{sub_state_name}.c3" not in cookie_info_map(driver) @@ -692,10 +692,7 @@ async def get_sub_state(): _substate_key(token or "", sub_state_name) ) state = root_state.substates[client_side.get_state_name("_client_side_state")] - sub_state = state.substates[ - client_side.get_state_name("_client_side_sub_state") - ] - return sub_state + return state.substates[client_side.get_state_name("_client_side_sub_state")] async def poll_for_c1_set(): sub_state = await get_sub_state() diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index ca0cc55b18e..d1078b8eb07 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -103,7 +103,7 @@ def index(): ) -@pytest.fixture() +@pytest.fixture def component_state_app(tmp_path) -> Generator[AppHarness, None, None]: """Start ComponentStateApp app at tmp_path via AppHarness. diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py index bca75270f22..c0d7e84e95f 100644 --- a/tests/integration/test_computed_vars.py +++ b/tests/integration/test_computed_vars.py @@ -2,7 +2,7 @@ from __future__ import annotations -import time +import asyncio from collections.abc import Generator import pytest @@ -160,7 +160,7 @@ def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]: driver.quit() -@pytest.fixture() +@pytest.fixture def token(computed_vars: AppHarness, driver: WebDriver) -> str: """Get a function that returns the active token. @@ -267,7 +267,7 @@ async def test_computed_vars( with pytest.raises(TimeoutError): _ = computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0") - time.sleep(10) + await asyncio.sleep(10) assert count3.text == "0" assert depends_on_count3.text == "0" mark_dirty.click() diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index e6b27617699..4f87bef6c9a 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -58,7 +58,7 @@ def simulate_compile_context(request) -> constants.CompileContext: return request.param -@pytest.fixture() +@pytest.fixture def connection_banner( tmp_path, simulate_compile_context: constants.CompileContext, diff --git a/tests/integration/test_deploy_url.py b/tests/integration/test_deploy_url.py index 9123fdb58db..fd44560f13c 100644 --- a/tests/integration/test_deploy_url.py +++ b/tests/integration/test_deploy_url.py @@ -21,6 +21,7 @@ class State(rx.State): def goto_self(self): if (deploy_url := rx.config.get_config().deploy_url) is not None: return rx.redirect(deploy_url) + return None def index(): return rx.fragment( @@ -50,7 +51,7 @@ def deploy_url_sample( yield harness -@pytest.fixture() +@pytest.fixture def driver(deploy_url_sample: AppHarness) -> Generator[WebDriver, None, None]: """WebDriver fixture for testing deploy_url. diff --git a/tests/integration/test_dynamic_components.py b/tests/integration/test_dynamic_components.py index 990333f116b..c26cdca9142 100644 --- a/tests/integration/test_dynamic_components.py +++ b/tests/integration/test_dynamic_components.py @@ -120,7 +120,8 @@ def poll_for_result( except exception: attempts += 1 time.sleep(seconds_between_attempts) - raise AssertionError("Function did not return a value") + msg = "Function did not return a value" + raise AssertionError(msg) @pytest.fixture diff --git a/tests/integration/test_dynamic_routes.py b/tests/integration/test_dynamic_routes.py index 9fa933a84d2..a7b6e8ec31e 100644 --- a/tests/integration/test_dynamic_routes.py +++ b/tests/integration/test_dynamic_routes.py @@ -2,7 +2,7 @@ from __future__ import annotations -import time +import asyncio from collections.abc import Callable, Coroutine, Generator from urllib.parse import urlsplit @@ -186,7 +186,7 @@ def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]: driver.quit() -@pytest.fixture() +@pytest.fixture def token(dynamic_route: AppHarness, driver: WebDriver) -> str: """Get the token associated with backend state. @@ -208,7 +208,7 @@ def token(dynamic_route: AppHarness, driver: WebDriver) -> str: return token -@pytest.fixture() +@pytest.fixture def poll_for_order( dynamic_route: AppHarness, token: str ) -> Callable[[list[str]], Coroutine[None, None, None]]: @@ -395,7 +395,7 @@ async def test_render_dynamic_arg( driver.get(f"{dynamic_route.frontend_url}/arg/0") # TODO: drop after flakiness is resolved - time.sleep(3) + await asyncio.sleep(3) def assert_content(expected: str, expect_not: str): ids = [ diff --git a/tests/integration/test_event_actions.py b/tests/integration/test_event_actions.py index d82c4e977fe..ddc9fc1f68f 100644 --- a/tests/integration/test_event_actions.py +++ b/tests/integration/test_event_actions.py @@ -57,7 +57,7 @@ def _get_custom_code(self) -> str | None: }""" def get_event_triggers(self): - return {"on_click": lambda: []} + return {"on_click": rx.event.no_args_event_spec} def index(): return rx.vstack( @@ -218,7 +218,7 @@ def driver(event_action: AppHarness) -> Generator[WebDriver, None, None]: driver.quit() -@pytest.fixture() +@pytest.fixture def token(event_action: AppHarness, driver: WebDriver) -> str: """Get the token associated with backend state. @@ -240,7 +240,7 @@ def token(event_action: AppHarness, driver: WebDriver) -> str: return token -@pytest.fixture() +@pytest.fixture def poll_for_order( event_action: AppHarness, token: str ) -> Callable[[list[str]], Coroutine[None, None, None]]: diff --git a/tests/integration/test_event_chain.py b/tests/integration/test_event_chain.py index 3b510caa09f..3552469f2d5 100644 --- a/tests/integration/test_event_chain.py +++ b/tests/integration/test_event_chain.py @@ -551,10 +551,10 @@ async def _has_all_events(): @pytest.mark.parametrize( - ("button_id",), + "button_id", [ - ("click_yield_interim_value_async",), - ("click_yield_interim_value",), + "click_yield_interim_value_async", + "click_yield_interim_value", ], ) def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str): diff --git a/tests/integration/test_exception_handlers.py b/tests/integration/test_exception_handlers.py index 0f3f94108a9..6430bc746a6 100644 --- a/tests/integration/test_exception_handlers.py +++ b/tests/integration/test_exception_handlers.py @@ -23,8 +23,6 @@ def TestApp(): class TestAppConfig(rx.Config): """Config for the TestApp app.""" - pass - class TestAppState(rx.State): """State for the TestApp app.""" @@ -134,10 +132,8 @@ def test_frontend_exception_handler_during_runtime( time.sleep(2) captured_default_handler_output = capsys.readouterr() - assert ( - "induce_frontend_error" in captured_default_handler_output.out - and "ReferenceError" in captured_default_handler_output.out - ) + assert "induce_frontend_error" in captured_default_handler_output.out + assert "ReferenceError" in captured_default_handler_output.out def test_backend_exception_handler_during_runtime( @@ -164,10 +160,8 @@ def test_backend_exception_handler_during_runtime( time.sleep(2) captured_default_handler_output = capsys.readouterr() - assert ( - "divide_by_number" in captured_default_handler_output.out - and "ZeroDivisionError" in captured_default_handler_output.out - ) + assert "divide_by_number" in captured_default_handler_output.out + assert "ZeroDivisionError" in captured_default_handler_output.out def test_frontend_exception_handler_with_react( diff --git a/tests/integration/test_form_submit.py b/tests/integration/test_form_submit.py index a2f0a32e956..0e0cbb50297 100644 --- a/tests/integration/test_form_submit.py +++ b/tests/integration/test_form_submit.py @@ -1,7 +1,7 @@ """Integration tests for forms.""" +import asyncio import functools -import time from collections.abc import Generator import pytest @@ -217,7 +217,7 @@ async def test_submit(driver, form_submit: AppHarness): debounce_input = driver.find_element(by, "debounce_input") debounce_input.send_keys("bar baz") - time.sleep(1) + await asyncio.sleep(1) prev_url = driver.current_url diff --git a/tests/integration/test_input.py b/tests/integration/test_input.py index 600ad2d68b6..3be518af389 100644 --- a/tests/integration/test_input.py +++ b/tests/integration/test_input.py @@ -55,7 +55,7 @@ def index(): ) -@pytest.fixture() +@pytest.fixture def fully_controlled_input(tmp_path) -> Generator[AppHarness, None, None]: """Start FullyControlledInput app at tmp_path via AppHarness. diff --git a/tests/integration/test_large_state.py b/tests/integration/test_large_state.py index a9a8ff2ec6f..701b8d33d54 100644 --- a/tests/integration/test_large_state.py +++ b/tests/integration/test_large_state.py @@ -71,7 +71,8 @@ def test_large_state(var_count: int, tmp_path_factory, benchmark): while button.text != "0": time.sleep(0.1) if time.time() - t > 30.0: - raise TimeoutError("Timeout waiting for initial state") + msg = "Timeout waiting for initial state" + raise TimeoutError(msg) times_clicked = 0 @@ -84,7 +85,8 @@ def round_trip(clicks: int, timeout: float): while button.text != str(times_clicked): time.sleep(0.005) if time.time() - t > timeout: - raise TimeoutError("Timeout waiting for state update") + msg = "Timeout waiting for state update" + raise TimeoutError(msg) benchmark(round_trip, clicks=10, timeout=30.0) finally: diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index c084c48bcec..d66c42f65a8 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -120,7 +120,7 @@ def mount_cached_fastapi(request: pytest.FixtureRequest) -> bool: return request.param -@pytest.fixture() +@pytest.fixture def lifespan_app( tmp_path, mount_api_transformer: bool, mount_cached_fastapi: bool ) -> Generator[AppHarness, None, None]: diff --git a/tests/integration/test_login_flow.py b/tests/integration/test_login_flow.py index 58e5e79a862..808ec74fbd6 100644 --- a/tests/integration/test_login_flow.py +++ b/tests/integration/test_login_flow.py @@ -67,7 +67,7 @@ def login_sample(tmp_path_factory) -> Generator[AppHarness, None, None]: yield harness -@pytest.fixture() +@pytest.fixture def driver(login_sample: AppHarness) -> Generator[WebDriver, None, None]: """Get an instance of the browser open to the login_sample app. @@ -85,7 +85,7 @@ def driver(login_sample: AppHarness) -> Generator[WebDriver, None, None]: driver.quit() -@pytest.fixture() +@pytest.fixture def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None]: """Get an instance of the local storage helper. diff --git a/tests/integration/test_media.py b/tests/integration/test_media.py index 7b4b89bf6cc..7c58486840e 100644 --- a/tests/integration/test_media.py +++ b/tests/integration/test_media.py @@ -75,7 +75,7 @@ def index(): ) -@pytest.fixture() +@pytest.fixture def media_app(tmp_path) -> Generator[AppHarness, None, None]: """Start MediaApp app at tmp_path via AppHarness. diff --git a/tests/integration/test_memo.py b/tests/integration/test_memo.py index b5343b0fa41..5bafac032b5 100644 --- a/tests/integration/test_memo.py +++ b/tests/integration/test_memo.py @@ -58,7 +58,7 @@ def index() -> rx.Component: app.add_page(index) -@pytest.fixture() +@pytest.fixture def memo_app(tmp_path) -> Generator[AppHarness, None, None]: """Start MemoApp app at tmp_path via AppHarness. diff --git a/tests/integration/test_navigation.py b/tests/integration/test_navigation.py index 16074df6e9c..421d864c8a6 100644 --- a/tests/integration/test_navigation.py +++ b/tests/integration/test_navigation.py @@ -40,7 +40,7 @@ def internal(): return rx.text("Internal") -@pytest.fixture() +@pytest.fixture def navigation_app(tmp_path) -> Generator[AppHarness, None, None]: """Start NavigationApp app at tmp_path via AppHarness. diff --git a/tests/integration/test_state_inheritance.py b/tests/integration/test_state_inheritance.py index 1bcc9cc53c9..608f6c1c1fc 100644 --- a/tests/integration/test_state_inheritance.py +++ b/tests/integration/test_state_inheritance.py @@ -242,7 +242,7 @@ def driver(state_inheritance: AppHarness) -> Generator[WebDriver, None, None]: driver.quit() -@pytest.fixture() +@pytest.fixture def token(state_inheritance: AppHarness, driver: WebDriver) -> str: """Get a function that returns the active token. diff --git a/tests/integration/test_tailwind.py b/tests/integration/test_tailwind.py index 53090050585..fff5f330d6f 100644 --- a/tests/integration/test_tailwind.py +++ b/tests/integration/test_tailwind.py @@ -75,7 +75,7 @@ def tailwind_version(request) -> int: return request.param -@pytest.fixture() +@pytest.fixture def tailwind_app(tmp_path, tailwind_version) -> Generator[AppHarness, None, None]: """Start TailwindApp app at tmp_path via AppHarness with tailwind disabled via config. diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index d4937af9414..e0568019d8b 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -341,7 +341,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): target_file.write_text(exp_contents) upload_box.send_keys(str(target_file)) - time.sleep(0.2) + await asyncio.sleep(0.2) # check that the selected files are displayed selected_files = driver.find_element(By.ID, "selected_files") diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index ff758cb756b..3e79aac20bf 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -591,7 +591,7 @@ def index(): ), rx.box( rx.foreach( - LiteralVar.create(list(range(0, 3))).to(ArrayVar, list[int]), + LiteralVar.create(list(range(3))).to(ArrayVar, list[int]), lambda x: rx.foreach( ArrayVar.range(x), lambda y: rx.text(VarOperationState.list1[y], as_="p"), diff --git a/tests/integration/tests_playwright/test_appearance.py b/tests/integration/tests_playwright/test_appearance.py index 4885caa128e..9af5fa38b73 100644 --- a/tests/integration/tests_playwright/test_appearance.py +++ b/tests/integration/tests_playwright/test_appearance.py @@ -72,7 +72,7 @@ def index(): ) -@pytest.fixture() +@pytest.fixture def light_mode_app(tmp_path_factory) -> Generator[AppHarness, None, None]: """Start DefaultLightMode app at tmp_path via AppHarness. @@ -91,7 +91,7 @@ def light_mode_app(tmp_path_factory) -> Generator[AppHarness, None, None]: yield harness -@pytest.fixture() +@pytest.fixture def dark_mode_app(tmp_path_factory) -> Generator[AppHarness, None, None]: """Start DefaultDarkMode app at tmp_path via AppHarness. @@ -110,7 +110,7 @@ def dark_mode_app(tmp_path_factory) -> Generator[AppHarness, None, None]: yield harness -@pytest.fixture() +@pytest.fixture def system_mode_app(tmp_path_factory) -> Generator[AppHarness, None, None]: """Start DefaultSystemMode app at tmp_path via AppHarness. @@ -129,7 +129,7 @@ def system_mode_app(tmp_path_factory) -> Generator[AppHarness, None, None]: yield harness -@pytest.fixture() +@pytest.fixture def color_toggle_app(tmp_path_factory) -> Generator[AppHarness, None, None]: """Start ColorToggle app at tmp_path via AppHarness. diff --git a/tests/integration/tests_playwright/test_datetime_operations.py b/tests/integration/tests_playwright/test_datetime_operations.py index cd394b98d6c..9f86f9401ee 100644 --- a/tests/integration/tests_playwright/test_datetime_operations.py +++ b/tests/integration/tests_playwright/test_datetime_operations.py @@ -41,7 +41,7 @@ def index(): ) -@pytest.fixture() +@pytest.fixture def datetime_operations_app(tmp_path_factory) -> Generator[AppHarness, None, None]: """Start Table app at tmp_path via AppHarness. diff --git a/tests/integration/tests_playwright/test_link_hover.py b/tests/integration/tests_playwright/test_link_hover.py index 477700026d3..c4a473f410f 100644 --- a/tests/integration/tests_playwright/test_link_hover.py +++ b/tests/integration/tests_playwright/test_link_hover.py @@ -25,7 +25,7 @@ def index(): app.add_page(index, "/") -@pytest.fixture() +@pytest.fixture def link_app(tmp_path_factory) -> Generator[AppHarness, None, None]: with AppHarness.create( root=tmp_path_factory.mktemp("link_app"), diff --git a/tests/integration/tests_playwright/test_table.py b/tests/integration/tests_playwright/test_table.py index 845f314495f..21c0d0b29e5 100644 --- a/tests/integration/tests_playwright/test_table.py +++ b/tests/integration/tests_playwright/test_table.py @@ -55,7 +55,7 @@ def index(): ) -@pytest.fixture() +@pytest.fixture def table_app(tmp_path_factory) -> Generator[AppHarness, None, None]: """Start Table app at tmp_path via AppHarness. diff --git a/tests/test_node_version.py b/tests/test_node_version.py index 3621ddede48..dd524cfc9e9 100644 --- a/tests/test_node_version.py +++ b/tests/test_node_version.py @@ -32,7 +32,7 @@ def index(): return rx.heading("Node Version check v", TestNodeVersionState.node_version) -@pytest.fixture() +@pytest.fixture def node_version_app(tmp_path) -> Generator[AppHarness, Any, None]: """Fixture to start TestNodeVersionApp app at tmp_path via AppHarness. @@ -63,8 +63,7 @@ def get_latest_node_version(): # Assuming the first entry in the API response is the most recent version if versions: - latest_version = versions[0]["version"] - return latest_version + return versions[0]["version"] return None assert node_version_app.frontend_url is not None diff --git a/tests/units/assets/test_assets.py b/tests/units/assets/test_assets.py index c23ec27a405..74d0e42fd08 100644 --- a/tests/units/assets/test_assets.py +++ b/tests/units/assets/test_assets.py @@ -38,7 +38,7 @@ def test_shared_asset() -> None: @pytest.mark.parametrize( - "path,shared", + ("path", "shared"), [ pytest.param("non_existing_file", True), pytest.param("non_existing_file", False), diff --git a/tests/units/compiler/test_compiler.py b/tests/units/compiler/test_compiler.py index 26af944bfed..339fadbfe3b 100644 --- a/tests/units/compiler/test_compiler.py +++ b/tests/units/compiler/test_compiler.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize( - "fields,test_default,test_rest", + ("fields", "test_default", "test_rest"), [ ( [ImportVar(tag="axios", is_default=True)], @@ -51,7 +51,7 @@ def test_compile_import_statement( @pytest.mark.parametrize( - "import_dict,test_dicts", + ("import_dict", "test_dicts"), [ ({}, []), ( diff --git a/tests/units/components/base/test_bare.py b/tests/units/components/base/test_bare.py index 6ae1e4db643..d36813badf2 100644 --- a/tests/units/components/base/test_bare.py +++ b/tests/units/components/base/test_bare.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( - "contents,expected", + ("contents", "expected"), [ ("hello", '"hello"'), ("{}", '"{}"'), diff --git a/tests/units/components/base/test_script.py b/tests/units/components/base/test_script.py index 2db63130746..0c285c18fc0 100644 --- a/tests/units/components/base/test_script.py +++ b/tests/units/components/base/test_script.py @@ -39,17 +39,14 @@ class EvState(BaseState): @rx.event def on_ready(self): """Empty event handler.""" - pass @rx.event def on_load(self): """Empty event handler.""" - pass @rx.event def on_error(self): """Empty event handler.""" - pass def test_script_event_handler(): diff --git a/tests/units/components/core/test_colors.py b/tests/units/components/core/test_colors.py index e6bc3ba797b..d23a5079cd9 100644 --- a/tests/units/components/core/test_colors.py +++ b/tests/units/components/core/test_colors.py @@ -29,7 +29,7 @@ def create_color_var(color): @pytest.mark.parametrize( - "color, expected, expected_type", + ("color", "expected", "expected_type"), [ (create_color_var(rx.color("mint")), '"var(--mint-7)"', Color), (create_color_var(rx.color("mint", 3)), '"var(--mint-3)"', Color), @@ -79,7 +79,7 @@ def test_color(color, expected, expected_type: type[str] | type[Color]): @pytest.mark.parametrize( - "cond_var, expected", + ("cond_var", "expected"), [ ( rx.cond(True, rx.color("mint"), rx.color("tomato", 5)), @@ -119,7 +119,7 @@ def test_color_with_conditionals(cond_var, expected): @pytest.mark.parametrize( - "color, expected", + ("color", "expected"), [ (create_color_var(rx.color("red")), '"var(--red-7)"'), (create_color_var(rx.color("green", shade=1)), '"var(--green-1)"'), diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index fb999d9eca6..9796bad760b 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -69,7 +69,7 @@ def test_validate_cond(cond_state: BaseState): @pytest.mark.parametrize( - "c1, c2", + ("c1", "c2"), [ (True, False), (32, 0), diff --git a/tests/units/components/core/test_debounce.py b/tests/units/components/core/test_debounce.py index 61a65cf70aa..84017f38240 100644 --- a/tests/units/components/core/test_debounce.py +++ b/tests/units/components/core/test_debounce.py @@ -44,7 +44,6 @@ def on_change(self, v: str): Args: v: The changed value. """ - pass def test_render_child_props(): @@ -57,7 +56,8 @@ def test_render_child_props(): on_change=S.on_change, ) )._render() - assert "css" in tag.props and isinstance(tag.props["css"], rx.vars.Var) + assert "css" in tag.props + assert isinstance(tag.props["css"], rx.vars.Var) for prop in ["foo", "bar", "baz", "quuc"]: assert prop in str(tag.props["css"]) assert tag.props["value"].equals( @@ -151,7 +151,8 @@ def test_render_child_props_recursive(): ), force_notify_by_enter=False, )._render() - assert "css" in tag.props and isinstance(tag.props["css"], rx.vars.Var) + assert "css" in tag.props + assert isinstance(tag.props["css"], rx.vars.Var) for prop in ["foo", "bar", "baz", "quuc"]: assert prop in str(tag.props["css"]) assert tag.props["value"].equals(LiteralVar.create("outer")) diff --git a/tests/units/components/core/test_foreach.py b/tests/units/components/core/test_foreach.py index 4e9e51c476c..b2115ac8ae4 100644 --- a/tests/units/components/core/test_foreach.py +++ b/tests/units/components/core/test_foreach.py @@ -144,7 +144,7 @@ def display_color_index_tuple(color): @pytest.mark.parametrize( - "state_var, render_fn, render_dict", + ("state_var", "render_fn", "render_dict"), [ ( ForEachState.colors_list, diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index ce17f84f22b..9b3f684c7ee 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -89,7 +89,7 @@ def test_match_components(): @pytest.mark.parametrize( - "cases, expected", + ("cases", "expected"), [ ( ( @@ -158,7 +158,8 @@ def test_match_on_component_without_default(): assert isinstance(match_comp, Component) default = match_comp.render()["children"][0]["default"] - assert isinstance(default, dict) and default["name"] == Fragment.__name__ + assert isinstance(default, dict) + assert default["name"] == Fragment.__name__ def test_match_on_var_no_default(): @@ -240,7 +241,7 @@ def test_match_case_tuple_elements(match_case): @pytest.mark.parametrize( - "cases, error_msg", + ("cases", "error_msg"), [ ( ( diff --git a/tests/units/components/core/test_upload.py b/tests/units/components/core/test_upload.py index efade7b63f3..3b03362d6e4 100644 --- a/tests/units/components/core/test_upload.py +++ b/tests/units/components/core/test_upload.py @@ -24,7 +24,6 @@ def drop_handler(self, files: Any): Args: files: The files dropped. """ - pass @event def not_drop_handler(self, not_files: Any): @@ -33,7 +32,6 @@ def not_drop_handler(self, not_files: Any): Args: not_files: The files dropped. """ - pass def test_cancel_upload(): diff --git a/tests/units/components/datadisplay/test_code.py b/tests/units/components/datadisplay/test_code.py index db0120fe1c1..85d60cf60b6 100644 --- a/tests/units/components/datadisplay/test_code.py +++ b/tests/units/components/datadisplay/test_code.py @@ -4,7 +4,7 @@ @pytest.mark.parametrize( - "theme, expected", + ("theme", "expected"), [(Theme.one_light, "oneLight"), (Theme.one_dark, "oneDark")], ) def test_code_light_dark_theme(theme, expected): diff --git a/tests/units/components/datadisplay/test_datatable.py b/tests/units/components/datadisplay/test_datatable.py index 2b47669222d..902e37575ef 100644 --- a/tests/units/components/datadisplay/test_datatable.py +++ b/tests/units/components/datadisplay/test_datatable.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize( - "data_table_state,expected", + ("data_table_state", "expected"), [ pytest.param( { @@ -73,7 +73,7 @@ def test_invalid_props(props): @pytest.mark.parametrize( - "fixture, err_msg, is_data_frame", + ("fixture", "err_msg", "is_data_frame"), [ ( "data_table_state2", @@ -114,11 +114,11 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram def test_serialize_dataframe(): """Test if dataframe is serialized correctly.""" - df = pd.DataFrame( + simple_dataframe = pd.DataFrame( [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"], # pyright: ignore [reportArgumentType] ) - value = serialize(df) - assert value == serialize_dataframe(df) + value = serialize(simple_dataframe) + assert value == serialize_dataframe(simple_dataframe) assert isinstance(value, dict) assert tuple(value) == ("columns", "data") diff --git a/tests/units/components/datadisplay/test_shiki_code.py b/tests/units/components/datadisplay/test_shiki_code.py index e1c7984f1df..05553815409 100644 --- a/tests/units/components/datadisplay/test_shiki_code.py +++ b/tests/units/components/datadisplay/test_shiki_code.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize( - "library, fns, expected_output, raises_exception", + ("library", "fns", "expected_output", "raises_exception"), [ ("some_library", ["function_one"], ["function_one"], False), ("some_library", [123], None, True), @@ -47,7 +47,7 @@ def test_create_transformer(library, fns, expected_output, raises_exception): @pytest.mark.parametrize( - "code_block, children, props, expected_first_child, expected_styles", + ("code_block", "children", "props", "expected_first_child", "expected_styles"), [ ("print('Hello')", ["print('Hello')"], {}, "print('Hello')", {}), ( @@ -106,7 +106,7 @@ def test_create_shiki_code_block( @pytest.mark.parametrize( - "children, props, expected_transformers, expected_button_type", + ("children", "props", "expected_transformers", "expected_button_type"), [ (["print('Hello')"], {"use_transformers": True}, [ShikiJsTransformer], None), (["print('Hello')"], {"can_copy": True}, None, Button), @@ -151,7 +151,7 @@ def test_create_shiki_high_level_code_block( @pytest.mark.parametrize( - "children, props", + ("children", "props"), [ (["print('Hello')"], {"theme": "dark"}), (["print('Hello')"], {"language": "javascript"}), diff --git a/tests/units/components/graphing/test_plotly.py b/tests/units/components/graphing/test_plotly.py index 69b046bea34..d060a0fd3d3 100644 --- a/tests/units/components/graphing/test_plotly.py +++ b/tests/units/components/graphing/test_plotly.py @@ -14,7 +14,8 @@ def plotly_fig() -> go.Figure: A random plotly figure. """ # Generate random data. - data = np.random.randint(0, 10, size=(10, 4)) + rng = np.random.default_rng() + data = rng.integers(0, 10, size=(10, 4)) trace = go.Scatter( x=list(range(len(data))), y=data[:, 0], mode="lines", name="Trace 1" ) diff --git a/tests/units/components/markdown/test_markdown.py b/tests/units/components/markdown/test_markdown.py index ce3c8f5b3a1..5074b364401 100644 --- a/tests/units/components/markdown/test_markdown.py +++ b/tests/units/components/markdown/test_markdown.py @@ -56,7 +56,7 @@ def code_block_markdown(*children, **props): @pytest.mark.parametrize( - "fn_body, fn_args, explicit_return, expected", + ("fn_body", "fn_args", "explicit_return", "expected"), [ ( None, @@ -141,7 +141,7 @@ def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expe @pytest.mark.parametrize( - "key,component_map, expected", + ("key", "component_map", "expected"), [ ( "code", diff --git a/tests/units/components/media/test_image.py b/tests/units/components/media/test_image.py index 519ca735e13..94f33e32e56 100644 --- a/tests/units/components/media/test_image.py +++ b/tests/units/components/media/test_image.py @@ -16,7 +16,8 @@ def pil_image() -> Img: Returns: A random PIL image. """ - imarray = np.random.rand(100, 100, 3) * 255 + rng = np.random.default_rng() + imarray = rng.random((100, 100, 3)) * 255 return PIL.Image.fromarray(imarray.astype("uint8")).convert("RGBA") # pyright: ignore [reportAttributeAccessIssue] diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 90185d00193..f8cc7ac6864 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -302,7 +302,7 @@ def test_create_component(component1): @pytest.mark.parametrize( - "prop_name,var,expected", + ("prop_name", "var", "expected"), [ pytest.param( "text", @@ -539,7 +539,7 @@ def test_get_props(component1, component2): @pytest.mark.parametrize( - "text,number", + ("text", "number"), [ ("", 0), ("test", 1), @@ -560,7 +560,7 @@ def test_valid_props(component1, text: str, number: int): @pytest.mark.parametrize( - "text,number", [("", "bad_string"), (13, 1), ("test", [1, 2, 3])] + ("text", "number"), [("", "bad_string"), (13, 1), ("test", [1, 2, 3])] ) def test_invalid_prop_type(component1, text: str, number: int): """Test that an invalid prop type raises an error. @@ -674,7 +674,7 @@ def test_component_create_unallowed_types(children, test_component): @pytest.mark.parametrize( - "element, expected", + ("element", "expected"), [ ( (rx.text("first_text"),), @@ -842,7 +842,6 @@ class C1State(BaseState): def mock_handler(self, _e: JavascriptInputEvent, _bravo: dict, _charlie: _Obj): """Mock handler.""" - pass def test_component_event_trigger_arbitrary_args(): @@ -1177,7 +1176,7 @@ def test_component_with_only_valid_children(fixture, request): @pytest.mark.parametrize( - "component,rendered", + ("component", "rendered"), [ (rx.text("hi"), 'jsx(\nRadixThemesText,\n{as:"p"},\n"hi"\n,)'), ( @@ -1306,7 +1305,7 @@ def handler2(self, arg): @pytest.mark.parametrize( ("component", "exp_vars"), - ( + [ pytest.param( Bare.create(TEST_VAR), [TEST_VAR], @@ -1473,7 +1472,7 @@ def handler2(self, arg): [FORMATTED_TEST_VAR_LIST_OF_DICT], id="fstring-list_of_dict", ), - ), + ], ) def test_get_vars(component, exp_vars): comp_vars = sorted(component._get_vars(), key=lambda v: v._js_expr) @@ -1536,8 +1535,6 @@ def test_instantiate_all_components(): class InvalidParentComponent(Component): """Invalid Parent Component.""" - ... - class ValidComponent1(Component): """Test valid component.""" @@ -1548,8 +1545,6 @@ class ValidComponent1(Component): class ValidComponent2(Component): """Test valid component.""" - ... - class ValidComponent3(Component): """Test valid component.""" @@ -1566,8 +1561,6 @@ class ValidComponent4(Component): class InvalidComponent(Component): """Test invalid component.""" - ... - valid_component1 = ValidComponent1.create valid_component2 = ValidComponent2.create @@ -1867,8 +1860,8 @@ def get_event_triggers(self) -> dict[str, Any]: return { **super().get_event_triggers(), "on_b": input_event, - "on_d": lambda: [], - "on_e": lambda: [], + "on_d": no_args_event_spec, + "on_e": no_args_event_spec, } class TestComponent(Component): @@ -1902,7 +1895,7 @@ def get_event_triggers(self) -> dict[str, Any]: """ return { **super().get_event_triggers(), - "on_a": lambda: [], + "on_a": no_args_event_spec, } trigger_comp = TriggerComponent.create @@ -1917,13 +1910,13 @@ def get_event_triggers(self) -> dict[str, Any]: @pytest.mark.parametrize( "tags", - ( + [ ["Component"], ["Component", "useState"], [ImportVar(tag="Component")], [ImportVar(tag="Component"), ImportVar(tag="useState")], ["Component", ImportVar(tag="useState")], - ), + ], ) def test_component_add_imports(tags): class BaseComponent(Component): @@ -2232,11 +2225,10 @@ class TriggerState(rx.State): @rx.event def do_something(self): """Sample event handler.""" - pass @pytest.mark.parametrize( - "component, output", + ("component", "output"), [ (rx.box(rx.text("random text")), False), ( diff --git a/tests/units/components/test_component_future_annotations.py b/tests/units/components/test_component_future_annotations.py index 0867a2d378d..74c917e90c6 100644 --- a/tests/units/components/test_component_future_annotations.py +++ b/tests/units/components/test_component_future_annotations.py @@ -20,7 +20,7 @@ def get_event_triggers(self) -> dict[str, Any]: "on_a": lambda e: [e], "on_b": lambda e: [e.target.value], "on_c": lambda e: [], - "on_d": lambda: [], + "on_d": no_args_event_spec, } class TestComponent(Component): diff --git a/tests/units/components/test_props.py b/tests/units/components/test_props.py index 2f146e83c0d..81c37b388cc 100644 --- a/tests/units/components/test_props.py +++ b/tests/units/components/test_props.py @@ -20,7 +20,7 @@ class PropB(NoExtrasAllowedProps): @pytest.mark.parametrize( - "props_class, kwargs, should_raise", + ("props_class", "kwargs", "should_raise"), [ (PropA, {"foo": "value", "bar": "another_value"}, False), (PropA, {"fooz": "value", "bar": "another_value"}, True), @@ -96,7 +96,7 @@ class OptionalFieldProps(PropsBase): @pytest.mark.parametrize( - "props_class, props_kwargs, expected_dict", + ("props_class", "props_kwargs", "expected_dict"), [ # Test single word + snake_case conversion ( diff --git a/tests/units/components/test_tag.py b/tests/units/components/test_tag.py index b3ff2b8edc7..431c04bdf8b 100644 --- a/tests/units/components/test_tag.py +++ b/tests/units/components/test_tag.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize( - "props,test_props", + ("props", "test_props"), [ ({}, []), ({"key-hypen": 1}, ['"key-hypen":1']), @@ -27,7 +27,7 @@ def test_format_props(props: dict[str, Var], test_props: list): @pytest.mark.parametrize( - "prop,valid", + ("prop", "valid"), [ (1, True), (3.14, True), @@ -58,7 +58,7 @@ def test_add_props(): @pytest.mark.parametrize( - "tag,expected", + ("tag", "expected"), [ (Tag(), {"name": "", "contents": "", "props": {}}), (Tag(name="br"), {"name": "br", "contents": "", "props": {}}), diff --git a/tests/units/components/typography/test_markdown.py b/tests/units/components/typography/test_markdown.py index 12f3b0dbe31..0157f972abb 100644 --- a/tests/units/components/typography/test_markdown.py +++ b/tests/units/components/typography/test_markdown.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize( - "tag,expected", + ("tag", "expected"), [ ("h1", "Heading"), ("h2", "Heading"), diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 00bbfb1043b..54e0923c3c6 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -59,13 +59,13 @@ def app_module_mock(monkeypatch) -> mock.Mock: @pytest.fixture(scope="session") -def windows_platform() -> Generator: +def windows_platform() -> bool: """Check if system is windows. - Yields: + Returns: whether system is windows. """ - yield platform.system() == "Windows" + return platform.system() == "Windows" @pytest.fixture @@ -217,7 +217,7 @@ def mutable_state() -> MutableTestState: return MutableTestState() -@pytest.fixture(scope="function") +@pytest.fixture def token() -> str: """Create a token. diff --git a/tests/units/states/mutation.py b/tests/units/states/mutation.py index 3cb41bf3bf5..1d2153d9087 100644 --- a/tests/units/states/mutation.py +++ b/tests/units/states/mutation.py @@ -150,8 +150,6 @@ class CustomVar(rx.Base): class MutableSQLABase(DeclarativeBase): """SQLAlchemy base model for mutable vars.""" - pass - class MutableSQLAModel(MutableSQLABase): """SQLAlchemy model for mutable vars.""" diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index 6942a430b4b..1c2d32a3bb6 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -20,7 +20,6 @@ async def handle_upload1(self, files: list[rx.UploadFile]): Args: files: The uploaded files. """ - pass class SubUploadState(UploadBaseState): @@ -34,7 +33,6 @@ async def handle_upload(self, files: list[rx.UploadFile]): Args: files: The uploaded files. """ - pass class FileUploadState(State): @@ -49,7 +47,6 @@ async def handle_upload2(self, files): Args: files: The uploaded files. """ - pass async def multi_handle_upload(self, files: list[rx.UploadFile]): """Handle the upload of a file. @@ -75,14 +72,11 @@ async def bg_upload(self, files: list[rx.UploadFile]): Args: files: The uploaded files. """ - pass class FileStateBase1(State): """The base state for a child FileUploadState.""" - pass - class ChildFileUploadState(FileStateBase1): """The child state for uploading a file.""" @@ -96,7 +90,6 @@ async def handle_upload2(self, files): Args: files: The uploaded files. """ - pass async def multi_handle_upload(self, files: list[rx.UploadFile]): """Handle the upload of a file. @@ -122,14 +115,11 @@ async def bg_upload(self, files: list[rx.UploadFile]): Args: files: The uploaded files. """ - pass class FileStateBase2(FileStateBase1): """The parent state for a grandchild FileUploadState.""" - pass - class GrandChildFileUploadState(FileStateBase2): """The child state for uploading a file.""" @@ -143,7 +133,6 @@ async def handle_upload2(self, files): Args: files: The uploaded files. """ - pass async def multi_handle_upload(self, files: list[rx.UploadFile]): """Handle the upload of a file. @@ -169,4 +158,3 @@ async def bg_upload(self, files: list[rx.UploadFile]): Args: files: The uploaded files. """ - pass diff --git a/tests/units/test_app.py b/tests/units/test_app.py index da8b5ee7508..0372ff73af0 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -68,8 +68,6 @@ class EmptyState(BaseState): """An empty state.""" - pass - @pytest.fixture def index_page() -> ComponentCallable: @@ -105,7 +103,7 @@ class ATestState(BaseState): var: int -@pytest.fixture() +@pytest.fixture def test_state() -> type[BaseState]: """A default state. @@ -115,7 +113,7 @@ def test_state() -> type[BaseState]: return ATestState -@pytest.fixture() +@pytest.fixture def redundant_test_state() -> type[BaseState]: """A default state. @@ -154,12 +152,10 @@ def test_model_auth() -> type[Model]: class TestModelAuth(Model, table=True): """A test model with auth.""" - pass - return TestModelAuth -@pytest.fixture() +@pytest.fixture def test_get_engine(): """A default database engine. @@ -175,7 +171,7 @@ def test_get_engine(): ) -@pytest.fixture() +@pytest.fixture def test_custom_auth_admin() -> type[AuthProvider]: """A default auth provider. @@ -191,19 +187,15 @@ class TestAuthProvider(AuthProvider): def login(self): # pyright: ignore [reportIncompatibleMethodOverride] """Login.""" - pass def is_authenticated(self): # pyright: ignore [reportIncompatibleMethodOverride] """Is authenticated.""" - pass def get_admin_user(self): # pyright: ignore [reportIncompatibleMethodOverride] """Get admin user.""" - pass def logout(self): # pyright: ignore [reportIncompatibleMethodOverride] """Logout.""" - pass return TestAuthProvider @@ -1594,7 +1586,7 @@ def page2(): assert str(third_text.children[0].contents) == '"third"' -@pytest.mark.parametrize("export", (True, False)) +@pytest.mark.parametrize("export", [True, False]) def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: bool): class C1(rx.Component): library = "foo@1.2.3" @@ -1736,7 +1728,7 @@ def handle(self, exception: Exception): @pytest.mark.parametrize( - "handler_fn, expected", + ("handler_fn", "expected"), [ pytest.param( custom_exception_handlers["partial"], @@ -1799,7 +1791,7 @@ def backend_exception_handler_with_wrong_return_type(exception: Exception) -> in @pytest.mark.parametrize( - "handler_fn, expected", + ("handler_fn", "expected"), [ pytest.param( backend_exception_handler_with_wrong_return_type, diff --git a/tests/units/test_attribute_access_type.py b/tests/units/test_attribute_access_type.py index 370af6f50a1..7f851791b30 100644 --- a/tests/units/test_attribute_access_type.py +++ b/tests/units/test_attribute_access_type.py @@ -300,12 +300,12 @@ class AttrClass: count: int = 0 name: str = "test" - int_list: list[int] = [] - str_list: list[str] = [] + int_list: list[int] = attrs.field(factory=list) + str_list: list[str] = attrs.field(factory=list) optional_int: int | None = None sqla_tag: SQLATag | None = None - labels: list[SQLALabel] = [] - dict_str_str: dict[str, str] = {} + labels: list[SQLALabel] = attrs.field(factory=list) + dict_str_str: dict[str, str] = attrs.field(factory=dict) default_factory: list[int] = attrs.field(factory=list) @property @@ -348,7 +348,7 @@ def first_label(self) -> SQLALabel | None: ], ) @pytest.mark.parametrize( - "attr, expected", + ("attr", "expected"), [ pytest.param("count", int, id="int"), pytest.param("name", str, id="str"), diff --git a/tests/units/test_config.py b/tests/units/test_config.py index a779a1d63b1..1ad3002e9ad 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -36,7 +36,7 @@ def test_set_app_name(base_config_values): @pytest.mark.parametrize( - "env_var, value", + ("env_var", "value"), [ ("APP_NAME", "my_test_app"), ("FRONTEND_PORT", 3001), @@ -96,7 +96,7 @@ def test_update_from_env_path( @pytest.mark.parametrize( - "kwargs, expected", + ("kwargs", "expected"), [ ( {"app_name": "test_app", "api_url": "http://example.com"}, @@ -243,7 +243,7 @@ def test_interpret_int_env() -> None: assert interpret_int_env("3001", "FRONTEND_PORT") == 3001 -@pytest.mark.parametrize("value, expected", [("true", True), ("false", False)]) +@pytest.mark.parametrize(("value", "expected"), [("true", True), ("false", False)]) def test_interpret_bool_env(value: str, expected: bool) -> None: assert interpret_boolean_env(value, "TELEMETRY_ENABLED") == expected diff --git a/tests/units/test_db_config.py b/tests/units/test_db_config.py index 5b716e6bbad..d778098713a 100644 --- a/tests/units/test_db_config.py +++ b/tests/units/test_db_config.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize( - "engine,username,password,host,port,database,expected_url", + ("engine", "username", "password", "host", "port", "database", "expected_url"), [ ( "postgresql", @@ -116,7 +116,7 @@ def test_constructor_sqlite(): @pytest.mark.parametrize( - "username,password,host,port,database,expected_url", + ("username", "password", "host", "port", "database", "expected_url"), [ ( "user", @@ -156,7 +156,7 @@ def test_constructor_postgresql(username, password, host, port, database, expect @pytest.mark.parametrize( - "username,password,host,port,database,expected_url", + ("username", "password", "host", "port", "database", "expected_url"), [ ( "user", diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 50411f05899..4732957b8e3 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -46,10 +46,10 @@ def test_fn(): test_fn.__qualname__ = "test_fn" - def test_fn_with_args(_, arg1, arg2): + def fn_with_args(_, arg1, arg2): pass - test_fn_with_args.__qualname__ = "test_fn_with_args" + fn_with_args.__qualname__ = "fn_with_args" handler = EventHandler(fn=test_fn) event_spec = handler() @@ -58,7 +58,7 @@ def test_fn_with_args(_, arg1, arg2): assert event_spec.args == () assert format.format_event(event_spec) == 'Event("test_fn", {})' - handler = EventHandler(fn=test_fn_with_args) + handler = EventHandler(fn=fn_with_args) event_spec = handler(make_var("first"), make_var("second")) # Test passing vars as args. @@ -69,22 +69,22 @@ def test_fn_with_args(_, arg1, arg2): assert event_spec.args[1][1].equals(Var(_js_expr="second")) assert ( format.format_event(event_spec) - == 'Event("test_fn_with_args", {arg1:first,arg2:second})' + == 'Event("fn_with_args", {arg1:first,arg2:second})' ) # Passing args as strings should format differently. event_spec = handler("first", "second") assert ( format.format_event(event_spec) - == 'Event("test_fn_with_args", {arg1:"first",arg2:"second"})' + == 'Event("fn_with_args", {arg1:"first",arg2:"second"})' ) first, second = 123, "456" - handler = EventHandler(fn=test_fn_with_args) + handler = EventHandler(fn=fn_with_args) event_spec = handler(first, second) assert ( format.format_event(event_spec) - == 'Event("test_fn_with_args", {arg1:123,arg2:"456"})' + == 'Event("fn_with_args", {arg1:123,arg2:"456"})' ) assert event_spec.handler == handler @@ -93,7 +93,7 @@ def test_fn_with_args(_, arg1, arg2): assert event_spec.args[1][0].equals(Var(_js_expr="arg2")) assert event_spec.args[1][1].equals(LiteralVar.create(second)) - handler = EventHandler(fn=test_fn_with_args) + handler = EventHandler(fn=fn_with_args) with pytest.raises(TypeError): handler(test_fn) @@ -101,15 +101,15 @@ def test_fn_with_args(_, arg1, arg2): def test_call_event_handler_partial(): """Calling an EventHandler with incomplete args returns an EventSpec that can be extended.""" - def test_fn_with_args(_, arg1, arg2): + def fn_with_args(_, arg1, arg2): pass - test_fn_with_args.__qualname__ = "test_fn_with_args" + fn_with_args.__qualname__ = "fn_with_args" def spec(a2: Var[str]) -> list[Var[str]]: return [a2] - handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState") + handler = EventHandler(fn=fn_with_args, state_full_name="BigState") event_spec = handler(make_var("first")) event_spec2 = call_event_handler(event_spec, spec) @@ -119,7 +119,7 @@ def spec(a2: Var[str]) -> list[Var[str]]: assert event_spec.args[0][1].equals(Var(_js_expr="first")) assert ( format.format_event(event_spec) - == 'Event("BigState.test_fn_with_args", {arg1:first})' + == 'Event("BigState.fn_with_args", {arg1:first})' ) assert event_spec2 is not event_spec @@ -131,17 +131,17 @@ def spec(a2: Var[str]) -> list[Var[str]]: assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str)) assert ( format.format_event(event_spec2) - == 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})' + == 'Event("BigState.fn_with_args", {arg1:first,arg2:_a2})' ) @pytest.mark.parametrize( ("arg1", "arg2"), - ( + [ (1, 2), (1, "2"), ({"a": 1}, {"b": 2}), - ), + ], ) def test_fix_events(arg1, arg2): """Test that chaining an event handler with args formats the payload correctly. @@ -151,21 +151,21 @@ def test_fix_events(arg1, arg2): arg2: The second arg passed to the handler. """ - def test_fn_with_args(_, arg1, arg2): + def fn_with_args(_, arg1, arg2): pass - test_fn_with_args.__qualname__ = "test_fn_with_args" + fn_with_args.__qualname__ = "fn_with_args" - handler = EventHandler(fn=test_fn_with_args) + handler = EventHandler(fn=fn_with_args) event_spec = handler(arg1, arg2) event = fix_events([event_spec], token="foo")[0] - assert event.name == test_fn_with_args.__qualname__ + assert event.name == fn_with_args.__qualname__ assert event.token == "foo" assert event.payload == {"arg1": arg1, "arg2": arg2} @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ( ("/path", None, None), diff --git a/tests/units/test_health_endpoint.py b/tests/units/test_health_endpoint.py index d0dff4f2bc3..b293df808b2 100644 --- a/tests/units/test_health_endpoint.py +++ b/tests/units/test_health_endpoint.py @@ -13,7 +13,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize( - "mock_redis_client, expected_status", + ("mock_redis_client", "expected_status"), [ # Case 1: Redis client is available and responds to ping (Mock(ping=lambda: None), {"redis": True}), @@ -41,7 +41,7 @@ async def test_get_redis_status( @pytest.mark.asyncio @pytest.mark.parametrize( - "mock_engine, execute_side_effect, expected_status", + ("mock_engine", "execute_side_effect", "expected_status"), [ # Case 1: Database is accessible (MagicMock(), None, {"db": True}), @@ -79,7 +79,14 @@ async def test_get_db_status( @pytest.mark.asyncio @pytest.mark.parametrize( - "db_enabled, redis_enabled, db_status, redis_status, expected_status, expected_code", + ( + "db_enabled", + "redis_enabled", + "db_status", + "redis_status", + "expected_status", + "expected_code", + ), [ # Case 1: Both services are connected (True, True, True, True, {"status": True, "db": True, "redis": True}, 200), diff --git a/tests/units/test_prerequisites.py b/tests/units/test_prerequisites.py index 46cecd68fbd..1a162f7686b 100644 --- a/tests/units/test_prerequisites.py +++ b/tests/units/test_prerequisites.py @@ -22,7 +22,7 @@ @pytest.mark.parametrize( - "config, export, expected_output", + ("config", "export", "expected_output"), [ ( Config( @@ -88,7 +88,7 @@ def test_update_next_config(config, export, expected_output): @pytest.mark.parametrize( ("transpile_packages", "expected_transpile_packages"), - ( + [ ( ["foo", "@bar/baz"], ["@bar/baz", "foo"], @@ -99,7 +99,7 @@ def test_update_next_config(config, export, expected_output): ), (["@bar/baz", {"name": "foo"}], ["@bar/baz", "foo"]), (["@bar/baz", {"name": "@foo/baz"}], ["@bar/baz", "@foo/baz"]), - ), + ], ) def test_transpile_packages(transpile_packages, expected_transpile_packages): output = _update_next_config( @@ -178,7 +178,7 @@ def temp_directory(): @pytest.mark.parametrize( - "config_code,expected", + ("config_code", "expected"), [ ("rx.Config(app_name='old_name')", 'rx.Config(app_name="new_name")'), ('rx.Config(app_name="old_name")', 'rx.Config(app_name="new_name")'), diff --git a/tests/units/test_route.py b/tests/units/test_route.py index ecb252410fd..50b3423b4f6 100644 --- a/tests/units/test_route.py +++ b/tests/units/test_route.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( - "route_name, expected", + ("route_name", "expected"), [ ("/users/[id]", {"id": constants.RouteArgType.SINGLE}), ( @@ -36,7 +36,7 @@ def test_invalid_route_args(route_name): @pytest.mark.parametrize( - "route_name,expected", + ("route_name", "expected"), [ ("/events/[year]/[month]/[...slug]", "[...slug]"), ("pages/shop/[[...slug]]", "[[...slug]]"), @@ -73,13 +73,13 @@ def test_verify_invalid_routes(route_name): verify_route_validity(route_name) -@pytest.fixture() +@pytest.fixture def app(): return App() @pytest.mark.parametrize( - "route1,route2", + ("route1", "route2"), [ ("/posts/[slug]", "/posts/[slug1]"), ("/posts/[slug]/info", "/posts/[slug1]/info1"), @@ -96,7 +96,7 @@ def test_check_routes_conflict_invalid(mocker: MockerFixture, app, route1, route @pytest.mark.parametrize( - "route1,route2", + ("route1", "route2"), [ ("/posts/[slug]", "/post/[slug1]"), ("/posts/[slug]", "/post/[slug]"), diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 1961fd77a73..0163fccde8b 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -143,7 +143,6 @@ def upper(self) -> str: def do_something(self): """Do something.""" - pass async def set_asynctest(self, value: int): """Set the asynctest value. Intentionally overwrite the default setter with an async one. @@ -190,7 +189,6 @@ class GrandchildState(ChildState): def do_nothing(self): """Do something.""" - pass class GrandchildState2(ChildState2): @@ -1662,7 +1660,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]: await state_manager.close() -@pytest.fixture() +@pytest.fixture def substate_token(state_manager, token) -> str: """A token + substate name for looking up in state manager. @@ -1764,7 +1762,7 @@ async def state_manager_redis() -> AsyncGenerator[StateManager, None]: await state_manager.close() -@pytest.fixture() +@pytest.fixture def substate_token_redis(state_manager_redis, token): """A token + substate name for looking up in state manager. @@ -1818,16 +1816,17 @@ async def test_state_manager_lock_expire_contend( state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD order = [] + waiter_event = asyncio.Event() async def _coro_blocker(): async with state_manager_redis.modify_state(substate_token_redis) as state: order.append("blocker") + waiter_event.set() await asyncio.sleep(LOCK_EXPIRE_SLEEP) state.num1 = unexp_num1 async def _coro_waiter(): - while "blocker" not in order: - await asyncio.sleep(0.005) + await waiter_event.wait() async with state_manager_redis.modify_state(substate_token_redis) as state: order.append("waiter") assert state.num1 != unexp_num1 @@ -1899,7 +1898,7 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) -@pytest.fixture(scope="function") +@pytest.fixture def mock_app_simple(monkeypatch) -> rx.App: """Simple Mock app fixture. @@ -1924,7 +1923,7 @@ def _mock_get_app(*args, **kwargs): return app -@pytest.fixture(scope="function") +@pytest.fixture def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App: """Mock app fixture. @@ -2586,10 +2585,10 @@ def assert_custom_dirty(): @pytest.mark.parametrize( - ("copy_func",), + "copy_func", [ - (copy.copy,), - (copy.deepcopy,), + copy.copy, + copy.deepcopy, ], ) def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable): @@ -2613,10 +2612,10 @@ def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable): @pytest.mark.parametrize( - ("copy_func",), + "copy_func", [ - (copy.copy,), - (copy.deepcopy,), + copy.copy, + copy.deepcopy, ], ) def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable): @@ -2637,9 +2636,9 @@ def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable) def test_duplicate_substate_class(mocker: MockerFixture): # Neuter pytest escape hatch, because we want to test duplicate detection. - mocker.patch("reflex.state.is_testing_env", lambda: False) + mocker.patch("reflex.state.is_testing_env", return_value=False) # Neuter state handling since these _are_ defined inside a function. - mocker.patch("reflex.state.BaseState._handle_local_def", lambda: None) + mocker.patch("reflex.state.BaseState._handle_local_def", return_value=False) with pytest.raises(ValueError): class TestState(BaseState): @@ -2885,7 +2884,7 @@ async def test_handler(self): @pytest.mark.asyncio @pytest.mark.parametrize( - "test_state, expected", + ("test_state", "expected"), [ (OnLoadState, {"on_load_state": {"num": 1}}), (OnLoadState2, {"on_load_state2": {"num": 1}}), @@ -3136,13 +3135,9 @@ class Parent(BaseState): class Child(Parent): """A state simulating UpdateVarsInternalState.""" - pass - class Child2(Parent): """An unconnected child state.""" - pass - class Child3(Parent): """A child state with a computed var causing it to be pre-fetched. @@ -3162,16 +3157,12 @@ class Grandchild3(Child3): invalid parent state names were being constructed. """ - pass - class GreatGrandchild3(Grandchild3): """Fetching this state wants to also fetch Child3 as a missing parent. However, Child3 should already be cached in the state tree because it has a computed var. """ - pass - mock_app.state_manager.state = mock_app._state = Parent # Get the top level state via unconnected sibling. @@ -3240,8 +3231,6 @@ async def test_router_var_dep(state_manager: StateManager, token: str) -> None: class RouterVarParentState(State): """A parent state for testing router var dependency.""" - pass - class RouterVarDepState(RouterVarParentState): """A state with a router var dependency.""" @@ -3326,7 +3315,7 @@ async def test_setvar_async_setter(): reason="Test requires redis", ) @pytest.mark.parametrize( - "expiration_kwargs, expected_values", + ("expiration_kwargs", "expected_values"), [ ( {"redis_lock_expiration": 20000}, @@ -3402,7 +3391,7 @@ def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_ reason="Test requires redis", ) @pytest.mark.parametrize( - "redis_lock_expiration, redis_lock_warning_threshold", + ("redis_lock_expiration", "redis_lock_warning_threshold"), [ (10000, 10000), (20000, 30000), @@ -3485,26 +3474,18 @@ def computed(self) -> str: class UsesMixinState(MixinState, State): """A state that uses the mixin state.""" - pass - class ChildUsesMixinState(UsesMixinState): """A child state that uses the mixin state.""" - pass - class ChildMixinState(ChildUsesMixinState, mixin=True): """A mixin state that inherits from a concrete state that uses mixins.""" - pass - class GrandchildUsesMixinState(ChildMixinState): """A grandchild state that uses the mixin state.""" - pass - class BareMixin: """A bare mixin which does not inherit from rx.State.""" @@ -3515,20 +3496,14 @@ class BareMixin: class BareStateMixin(BareMixin, rx.State, mixin=True): """A state mixin that uses a bare mixin.""" - pass - class BareMixinState(BareStateMixin, State): """A state that uses a bare mixin.""" - pass - class ChildBareMixinState(BareMixinState): """A child state that uses a bare mixin.""" - pass - def test_mixin_state() -> None: """Test that a mixin state works correctly.""" @@ -3963,8 +3938,6 @@ class Parent(BaseState): class Child2(Parent): """An unconnected child state.""" - pass - class Child3(Parent): """A child state with a computed var causing it to be pre-fetched. diff --git a/tests/units/test_style.py b/tests/units/test_style.py index 59742c46dd3..de479c2b987 100644 --- a/tests/units/test_style.py +++ b/tests/units/test_style.py @@ -46,7 +46,7 @@ @pytest.mark.parametrize( - "style_dict,expected", + ("style_dict", "expected"), test_style, ) def test_convert(style_dict, expected): @@ -61,7 +61,7 @@ def test_convert(style_dict, expected): @pytest.mark.parametrize( - "style_dict,expected", + ("style_dict", "expected"), test_style, ) def test_create_style(style_dict, expected): diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 326a4ab354d..f7562c7acf7 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -165,7 +165,8 @@ def StateWithRuntimeOnlyVar(): class StateWithRuntimeOnlyVar(BaseState): @computed_var(initial_value=None) def var_raises_at_runtime(self) -> str: - raise ValueError("So nicht, mein Freund") + msg = "So nicht, mein Freund" + raise ValueError(msg) return StateWithRuntimeOnlyVar @@ -175,13 +176,14 @@ def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar): class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar): @computed_var(initial_value="Initial value") def var_raises_at_runtime_child(self) -> str: - raise ValueError("So nicht, mein Freund") + msg = "So nicht, mein Freund" + raise ValueError(msg) return ChildWithRuntimeOnlyVar @pytest.mark.parametrize( - "prop,expected", + ("prop", "expected"), zip( test_vars, [ @@ -205,7 +207,7 @@ def test_full_name(prop, expected): @pytest.mark.parametrize( - "prop,expected", + ("prop", "expected"), zip( test_vars, ["prop1", "key", "state.value", "state.local", "local2"], @@ -246,7 +248,7 @@ def test_default_value(prop: Var, expected): @pytest.mark.parametrize( - "prop,expected", + ("prop", "expected"), zip( test_vars, [ @@ -270,7 +272,7 @@ def test_get_setter(prop: Var, expected): @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), [ (None, Var(_js_expr="null", _var_type=None)), (1, Var(_js_expr="1", _var_type=int)), @@ -378,7 +380,7 @@ def test_basic_operations(TestObj): @pytest.mark.parametrize( - "var, expected", + ("var", "expected"), [ (v([1, 2, 3]), "[1, 2, 3]"), (v({1, 2, 3}), "[1, 2, 3]"), @@ -428,7 +430,6 @@ class Bar(rx.Base): [ (Var(_js_expr="").to(Foo | Bar), Foo | Bar), (Var(_js_expr="").to(Foo | Bar).bar, int | str), - (Var(_js_expr="").to(Foo | Bar), Foo | Bar), (Var(_js_expr="").to(Foo | Bar).baz, str), ( Var(_js_expr="").to(Foo | Bar).foo, @@ -441,7 +442,7 @@ def test_var_types(var, var_type): @pytest.mark.parametrize( - "var, expected", + ("var", "expected"), [ (v("123"), json.dumps("123")), (Var(_js_expr="foo")._var_set_state("state").to(str), "state.foo"), @@ -462,7 +463,7 @@ def test_str_contains(var, expected): @pytest.mark.parametrize( - "var, expected", + ("var", "expected"), [ (v({"a": 1, "b": 2}), '({ ["a"] : 1, ["b"] : 2 })'), (Var(_js_expr="foo")._var_set_state("state").to(dict), "state.foo"), @@ -505,7 +506,7 @@ def test_var_indexing_lists(var): @pytest.mark.parametrize( - "var, type_", + ("var", "type_"), [ (Var(_js_expr="list", _var_type=list[int]).guess_type(), [int, int]), ( @@ -567,7 +568,7 @@ def test_var(state) -> int: @pytest.mark.parametrize( - "var, index", + ("var", "index"), [ (Var(_js_expr="lst", _var_type=list[int]).guess_type(), [1, 2]), ( @@ -691,7 +692,7 @@ def test_dict_indexing(): @pytest.mark.parametrize( - "var, index", + ("var", "index"), [ ( Var(_js_expr="dict", _var_type=dict[str, str]).guess_type(), @@ -839,7 +840,13 @@ def test_computed_var_with_annotation_error(request, fixture): @pytest.mark.parametrize( - "fixture,var_name,expected_initial,expected_runtime,raises_at_runtime", + ( + "fixture", + "var_name", + "expected_initial", + "expected_runtime", + "raises_at_runtime", + ), [ ( "StateWithInitialComputedVar", @@ -1051,7 +1058,7 @@ def test_index_operation(): @pytest.mark.parametrize( - "var, expected_js", + ("var", "expected_js"), [ (Var.create(float("inf")), "Infinity"), (Var.create(-float("inf")), "-Infinity"), @@ -1202,7 +1209,8 @@ def test_retrival(): result_var_data = LiteralVar.create(f_string)._get_all_var_data() result_immutable_var_data = Var(_js_expr=f_string)._var_data - assert result_var_data is not None and result_immutable_var_data is not None + assert result_var_data is not None + assert result_immutable_var_data is not None assert ( result_var_data.state == result_immutable_var_data.state @@ -1262,7 +1270,7 @@ def test_fstring_concat(): @pytest.mark.parametrize( - "out, expected", + ("out", "expected"), [ (f"{var}", f"{hash(var)}var"), ( @@ -1394,7 +1402,7 @@ def test_unsupported_default_contains(): @pytest.mark.parametrize( - "operand1_var,operand2_var,operators", + ("operand1_var", "operand2_var", "operators"), [ ( LiteralVar.create(10), @@ -1502,7 +1510,7 @@ def test_valid_var_operations(operand1_var: Var, operand2_var, operators: list[s @pytest.mark.parametrize( - "operand1_var,operand2_var,operators", + ("operand1_var", "operand2_var", "operators"), [ ( LiteralVar.create(10), @@ -1781,7 +1789,7 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: list @pytest.mark.parametrize( - "var, expected", + ("var", "expected"), [ (LiteralVar.create("string_value"), '"string_value"'), (LiteralVar.create(1), "1"), @@ -1810,7 +1818,7 @@ def cv_fget(state: BaseState) -> int: @pytest.mark.parametrize( - "deps,expected", + ("deps", "expected"), [ (["a"], {None: {"a"}}), (["b"], {None: {"b"}}), diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index e2387c60d87..a47ce3fb713 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -8,7 +8,13 @@ import pytest from reflex.components.tags.tag import Tag -from reflex.event import EventChain, EventHandler, EventSpec, JavascriptInputEvent +from reflex.event import ( + EventChain, + EventHandler, + EventSpec, + JavascriptInputEvent, + no_args_event_spec, +) from reflex.style import Style from reflex.utils import format from reflex.utils.serializers import serialize_figure @@ -31,7 +37,7 @@ def mock_event(arg): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ("{", "}"), ("(", ")"), @@ -52,7 +58,7 @@ def test_get_close_char(input: str, output: str): @pytest.mark.parametrize( - "text,open,expected", + ("text", "open", "expected"), [ ("", "{", False), ("{wrap}", "{", True), @@ -73,7 +79,7 @@ def test_is_wrapped(text: str, open: str, expected: bool): @pytest.mark.parametrize( - "text,open,check_first,num,expected", + ("text", "open", "check_first", "num", "expected"), [ ("", "{", True, 1, "{}"), ("wrap", "{", True, 1, "{wrap}"), @@ -99,17 +105,13 @@ def test_wrap(text: str, open: str, expected: str, check_first: bool, num: int): @pytest.mark.parametrize( - "string,expected_output", + ("string", "expected_output"), [ ("This is a random string", "This is a random string"), ( "This is a random string with `backticks`", "This is a random string with \\`backticks\\`", ), - ( - "This is a random string with `backticks`", - "This is a random string with \\`backticks\\`", - ), ( "This is a string with ${someValue[`string interpolation`]} unescaped", "This is a string with ${someValue[`string interpolation`]} unescaped", @@ -129,7 +131,7 @@ def test_escape_js_string(string, expected_output): @pytest.mark.parametrize( - "text,indent_level,expected", + ("text", "indent_level", "expected"), [ ("", 2, ""), ("hello", 2, "hello"), @@ -153,7 +155,7 @@ def test_indent(text: str, indent_level: int, expected: str, windows_platform: b @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ("", ""), ("hello", "hello"), @@ -179,7 +181,7 @@ def test_to_snake_case(input: str, output: str): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ("", ""), ("hello", "hello"), @@ -209,7 +211,7 @@ def test_to_camel_case(input: str, output: str): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ("", ""), ("hello", "Hello"), @@ -229,7 +231,7 @@ def test_to_title_case(input: str, output: str): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ("", ""), ("hello", "hello"), @@ -253,7 +255,7 @@ def test_to_kebab_case(input: str, output: str): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ("", "{``}"), ("hello", "{`hello`}"), @@ -272,7 +274,7 @@ def test_format_string(input: str, output: str): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ (LiteralVar.create(value="test"), '"test"'), (Var(_js_expr="test"), "test"), @@ -283,7 +285,7 @@ def test_format_var(input: Var, output: str): @pytest.mark.parametrize( - "route,format_case,expected", + ("route", "format_case", "expected"), [ ("", True, "index"), ("/", True, "index"), @@ -311,7 +313,7 @@ def test_format_route(route: str, format_case: bool, expected: bool): @pytest.mark.parametrize( - "condition, match_cases, default,expected", + ("condition", "match_cases", "default", "expected"), [ ( "state__state.value", @@ -350,7 +352,7 @@ def test_format_match( @pytest.mark.parametrize( - "prop,formatted", + ("prop", "formatted"), [ ("string", '"string"'), ("{wrapped_string}", '"{wrapped_string}"'), @@ -372,7 +374,7 @@ def test_format_match( ( EventChain( events=[EventSpec(handler=EventHandler(fn=mock_event))], - args_spec=lambda: [], + args_spec=no_args_event_spec, ), '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ }))))', ), @@ -400,7 +402,7 @@ def test_format_match( ( EventChain( events=[EventSpec(handler=EventHandler(fn=mock_event))], - args_spec=lambda: [], + args_spec=no_args_event_spec, event_actions={"stopPropagation": True}, ), '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true }))))', @@ -413,14 +415,14 @@ def test_format_match( event_actions={"stopPropagation": True}, ) ], - args_spec=lambda: [], + args_spec=no_args_event_spec, ), '((...args) => (addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ }))))', ), ( EventChain( events=[EventSpec(handler=EventHandler(fn=mock_event))], - args_spec=lambda: [], + args_spec=no_args_event_spec, event_actions={"preventDefault": True}, ), '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true }))))', @@ -472,7 +474,7 @@ def test_format_prop(prop: Var, formatted: str): @pytest.mark.parametrize( - "single_props,key_value_props,output", + ("single_props", "key_value_props", "output"), [ ( [Var(_js_expr="props")], @@ -493,7 +495,7 @@ def test_format_props(single_props, key_value_props, output): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ (EventHandler(fn=mock_event), ("", "mock_event")), ], @@ -503,7 +505,7 @@ def test_get_handler_parts(input, output): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ (TestState.do_something, f"{TestState.get_full_name()}.do_something"), ( @@ -527,7 +529,7 @@ def test_format_event_handler(input, output): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ( EventSpec(handler=EventHandler(fn=mock_event)), @@ -540,7 +542,7 @@ def test_format_event(input, output): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ({"query": {"k1": 1, "k2": 2}}, {"k1": 1, "k2": 2}), ({"query": {"k1": 1, "k-2": 2}}, {"k1": 1, "k_2": 2}), @@ -580,7 +582,7 @@ def test_format_query_params(input, output): @pytest.mark.parametrize( - "input, output", + ("input", "output"), [ ( TestState(_reflex_internal_init=True).dict(), # pyright: ignore [reportCallIssue] @@ -640,7 +642,7 @@ def test_format_state(input, output): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ ("input1", "ref_input1"), ("input 1", "ref_input_1"), @@ -660,7 +662,7 @@ def test_format_ref(input, output): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ (("my_array", None), "refs_my_array"), (("my_array", LiteralVar.create(0)), "refs_my_array[0]"), @@ -672,7 +674,7 @@ def test_format_array_ref(input, output): @pytest.mark.parametrize( - "input, output", + ("input", "output"), [ ("library@^0.1.2", "library"), ("library", "library"), @@ -691,7 +693,7 @@ def test_format_library_name(input: str, output: str): @pytest.mark.parametrize( - "input,output", + ("input", "output"), [ (None, "null"), (True, "true"), diff --git a/tests/units/utils/test_imports.py b/tests/units/utils/test_imports.py index c30d1d85c75..cfbeff8dc81 100644 --- a/tests/units/utils/test_imports.py +++ b/tests/units/utils/test_imports.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize( - "import_var, expected_name", + ("import_var", "expected_name"), [ ( ImportVar(tag="BaseTag"), @@ -49,7 +49,7 @@ def test_import_var(import_var, expected_name): @pytest.mark.parametrize( - "input_1, input_2, output", + ("input_1", "input_2", "output"), [ ( {"react": {"Component"}}, @@ -89,7 +89,7 @@ def test_merge_imports(input_1, input_2, output): @pytest.mark.parametrize( - "input, output", + ("input", "output"), [ ({}, {}), ( diff --git a/tests/units/utils/test_serializers.py b/tests/units/utils/test_serializers.py index c05ffbf9319..970b355cbf9 100644 --- a/tests/units/utils/test_serializers.py +++ b/tests/units/utils/test_serializers.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize( - "type_,expected", + ("type_", "expected"), [(Enum, True)], ) def test_has_serializer(type_: type, expected: bool): @@ -29,7 +29,7 @@ def test_has_serializer(type_: type, expected: bool): @pytest.mark.parametrize( - "type_,expected", + ("type_", "expected"), [ (datetime.datetime, serializers.serialize_datetime), (datetime.date, serializers.serialize_datetime), @@ -121,7 +121,7 @@ class BaseSubclass(Base): @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), [ ("test", "test"), (1, 1), @@ -205,7 +205,7 @@ def test_serialize(value: Any, expected: str): @pytest.mark.parametrize( - "value,expected,exp_var_is_string", + ("value", "expected", "exp_var_is_string"), [ ("test", '"test"', False), (1, "1", False), diff --git a/tests/units/utils/test_types.py b/tests/units/utils/test_types.py index 3746e4575e4..340861e06e9 100644 --- a/tests/units/utils/test_types.py +++ b/tests/units/utils/test_types.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( - "params, allowed_value_str, value_str", + ("params", "allowed_value_str", "value_str"), [ (["size", 1, Literal["1", "2", "3"], "Heading"], "'1','2','3'", "1"), (["size", "1", Literal[1, 2, 3], "Heading"], "1,2,3", "'1'"), @@ -24,7 +24,7 @@ def test_validate_literal_error_msg(params, allowed_value_str, value_str): @pytest.mark.parametrize( - "cls,cls_check,expected", + ("cls", "cls_check", "expected"), [ (int, Any, True), (tuple[int], Any, True), @@ -51,29 +51,21 @@ def test_issubclass( class CustomDict(dict[str, str]): """A custom dict with generic arguments.""" - pass - class ChildCustomDict(CustomDict): """A child of CustomDict.""" - pass - class GenericDict(dict): """A generic dict with no generic arguments.""" - pass - class ChildGenericDict(GenericDict): """A child of GenericDict.""" - pass - @pytest.mark.parametrize( - "cls,expected", + ("cls", "expected"), [ (int, False), (str, False), diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 47ce07fac5d..216f900d5d5 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -25,7 +25,6 @@ class ExampleTestState(BaseState): def test_event_handler(self): """Test event handler.""" - pass def test_func(): @@ -33,7 +32,7 @@ def test_func(): @pytest.mark.parametrize( - "cls,expected", + ("cls", "expected"), [ (str, False), (int, False), @@ -70,8 +69,6 @@ def test_is_generic_alias(cls: type, expected: bool): (int, bool, False), (list, list, True), (list, list[str], True), # this is wrong, but it's a limitation of the function - (list, list, True), - (list[int], list, True), (list[int], list, True), (list[int], list[str], False), (list[int], list[int], True), @@ -130,8 +127,6 @@ def test_typehint_issubclass(subclass, superclass, expected): (int, bool, False), (list, list, True), (list, list[str], True), # this is wrong, but it's a limitation of the function - (list, list, True), - (list[int], list, True), (list[int], list, True), (list[int], list[str], False), (list[int], list[int], True), @@ -283,7 +278,7 @@ def _cached_hidden_property(self): @pytest.mark.parametrize( - "input, output", + ("input", "output"), [ ("_classvar", False), ("_class_method", False), @@ -302,7 +297,7 @@ def test_is_backend_base_variable( @pytest.mark.parametrize( - "cls, cls_check, expected", + ("cls", "cls_check", "expected"), [ (int, int, True), (int, float, False), @@ -333,7 +328,7 @@ def test_unsupported_literals(cls: type): @pytest.mark.parametrize( - "app_name,expected_config_name", + ("app_name", "expected_config_name"), [ ("appname", "AppnameConfig"), ("app_name", "AppnameConfig"), @@ -395,18 +390,15 @@ def test_create_config_e2e(tmp_working_dir): class DataFrame: """A Fake pandas DataFrame class.""" - pass - @pytest.mark.parametrize( - "class_type,expected", + ("class_type", "expected"), [ (list, False), (int, False), (dict, False), (DataFrame, True), (typing.Any, False), - (list, False), ], ) def test_is_dataframe(class_type, expected): diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py index ed48b8ea33e..54b3d48182f 100644 --- a/tests/units/vars/test_base.py +++ b/tests/units/vars/test_base.py @@ -9,26 +9,18 @@ class CustomDict(dict[str, str]): """A custom dict with generic arguments.""" - pass - class ChildCustomDict(CustomDict): """A child of CustomDict.""" - pass - class GenericDict(dict): """A generic dict with no generic arguments.""" - pass - class ChildGenericDict(GenericDict): """A child of GenericDict.""" - pass - @pytest.mark.parametrize( ("value", "expected"), diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py index 72142c9fc52..6dc27fdb026 100644 --- a/tests/units/vars/test_object.py +++ b/tests/units/vars/test_object.py @@ -40,8 +40,6 @@ class Base(rx.Base): class SqlaBase(DeclarativeBase, MappedAsDataclass): """Sqlalchemy declarative mapping base class.""" - pass - class SqlaModel(SqlaBase): """A sqlalchemy model with a single attribute."""