Skip to content
66 changes: 28 additions & 38 deletions reflex/utils/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import signal
import socket
import subprocess
import sys
from collections.abc import Callable, Generator, Sequence
from concurrent import futures
from contextlib import closing
Expand Down Expand Up @@ -68,12 +69,11 @@ def _can_bind_at_port(
"""
try:
with closing(socket.socket(address_family, socket.SOCK_STREAM)) as sock:
if sys.platform != "win32":
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((address, port))
except OverflowError:
return False
except PermissionError:
return False
except OSError:
except (OverflowError, PermissionError, OSError) as e:
console.warn(f"Unable to bind to {address}:{port} due to: {e}.")
return False
return True

Expand All @@ -87,38 +87,13 @@ def is_process_on_port(port: int) -> bool:
Returns:
Whether a process is running on the given port.
"""
return not _can_bind_at_port( # Test IPv4 localhost (127.0.0.1)
socket.AF_INET, "127.0.0.1", port
) or not _can_bind_at_port(
socket.AF_INET6, "::1", port
) # Test IPv6 localhost (::1)
return (
not _can_bind_at_port(socket.AF_INET, "", port) # Test IPv4 local network
or not _can_bind_at_port(socket.AF_INET6, "", port) # Test IPv6 local network
)


def change_port(port: int, _type: str) -> int:
"""Change the port.

Args:
port: The port.
_type: The type of the port.

Returns:
The new port.

Raises:
Exit: If the port is invalid or if the new port is occupied.
"""
new_port = port + 1
if new_port < 0 or new_port > 65535:
console.error(
f"The {_type} port: {port} is invalid. It must be between 0 and 65535."
)
raise click.exceptions.Exit(1)
if is_process_on_port(new_port):
return change_port(new_port, _type)
console.info(
f"The {_type} will run on port [bold underline]{new_port}[/bold underline]."
)
return new_port
MAXIMUM_PORT = 2**16 - 1


def handle_port(service_name: str, port: int, auto_increment: bool) -> int:
Expand All @@ -137,13 +112,28 @@ def handle_port(service_name: str, port: int, auto_increment: bool) -> int:
Exit:when the port is in use.
"""
console.debug(f"Checking if {service_name.capitalize()} port: {port} is in use.")

if not is_process_on_port(port):
console.debug(f"{service_name.capitalize()} port: {port} is not in use.")
return port

if auto_increment:
return change_port(port, service_name)
console.error(f"{service_name.capitalize()} port: {port} is already in use.")
raise click.exceptions.Exit
for new_port in range(port + 1, MAXIMUM_PORT + 1):
if not is_process_on_port(new_port):
console.info(
f"The {service_name} will run on port [bold underline]{new_port}[/bold underline]."
)
return new_port
console.debug(
f"{service_name.capitalize()} port: {new_port} is already in use."
)

# If we reach here, it means we couldn't find an available port.
console.error(f"Unable to find an available port for {service_name}")
else:
console.error(f"{service_name.capitalize()} port: {port} is already in use.")

raise click.exceptions.Exit(1)


@overload
Expand Down
52 changes: 20 additions & 32 deletions tests/units/utils/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_is_process_on_port_free_port():
"""Test is_process_on_port returns False when port is free."""
# Find a free port
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(("127.0.0.1", 0))
sock.bind(("", 0))
free_port = sock.getsockname()[1]

# Port should be free after socket is closed
Expand All @@ -26,8 +26,7 @@ def test_is_process_on_port_occupied_port():
"""Test is_process_on_port returns True when port is occupied."""
# Create a server socket to occupy a port
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(("127.0.0.1", 0))
server_socket.bind(("", 0))
server_socket.listen(1)

occupied_port = server_socket.getsockname()[1]
Expand All @@ -44,8 +43,7 @@ def test_is_process_on_port_ipv6():
# Test with IPv6 socket
try:
server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(("::1", 0))
server_socket.bind(("", 0))
server_socket.listen(1)

occupied_port = server_socket.getsockname()[1]
Expand All @@ -64,8 +62,7 @@ def test_is_process_on_port_both_protocols():
"""Test is_process_on_port detects occupation on either IPv4 or IPv6."""
# Create IPv4 server
ipv4_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ipv4_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ipv4_socket.bind(("127.0.0.1", 0))
ipv4_socket.bind(("", 0))
ipv4_socket.listen(1)

port = ipv4_socket.getsockname()[1]
Expand Down Expand Up @@ -116,46 +113,37 @@ def test_is_process_on_port_permission_error():
assert result is True


@pytest.mark.parametrize("should_listen", [True, False])
def test_is_process_on_port_concurrent_access(should_listen):
"""Test is_process_on_port works correctly with concurrent access.
def test_is_process_on_port_concurrent_access():
"""Test is_process_on_port works correctly with concurrent access."""
shared = None

Args:
should_listen: Whether the server socket should call listen() or just bind().
"""

def create_server_and_test(port_holder, listen):
def create_server_and_test():
nonlocal shared
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind(("127.0.0.1", 0))
server.bind(("", 0))

if listen:
server.listen(1)
server.listen(1)

port = server.getsockname()[1]
port_holder[0] = port
shared = port

# Small delay to ensure the test runs while server is active
time.sleep(0.1)
server.close()

port_holder = [None]
thread = threading.Thread(
target=create_server_and_test, args=(port_holder, should_listen)
)
thread = threading.Thread(target=create_server_and_test)
thread.start()

# Wait a bit for the server to start
time.sleep(0.05)

if port_holder[0] is not None:
# Port should be occupied while server is running (both bound-only and listening)
assert is_process_on_port(port_holder[0])
assert shared is not None

# Port should be occupied while server is running (both bound-only and listening)
assert is_process_on_port(shared)

thread.join()

# After thread ends and server closes, port should be free
if port_holder[0] is not None:
# Give it a moment for the socket to be fully released
time.sleep(0.1)
assert not is_process_on_port(port_holder[0])
# Give it a moment for the socket to be fully released
time.sleep(0.1)
assert not is_process_on_port(shared)
Loading