Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions reflex/utils/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_num_workers() -> int:
return (os.cpu_count() or 1) * 2 + 1


def _is_address_responsive(
def _can_bind_at_port(
address_family: socket.AddressFamily | int, address: str, port: int
) -> bool:
"""Check if a given address and port are responsive.
Expand All @@ -68,9 +68,14 @@ def _is_address_responsive(
"""
try:
with closing(socket.socket(address_family, socket.SOCK_STREAM)) as sock:
return sock.connect_ex((address, port)) == 0
sock.bind((address, port))
except OverflowError:
return False
except PermissionError:
return False
except OSError:
return False
return True


def is_process_on_port(port: int) -> bool:
Expand All @@ -82,9 +87,9 @@ def is_process_on_port(port: int) -> bool:
Returns:
Whether a process is running on the given port.
"""
return _is_address_responsive( # Test IPv4 localhost (127.0.0.1)
return not _can_bind_at_port( # Test IPv4 localhost (127.0.0.1)
socket.AF_INET, "127.0.0.1", port
) or _is_address_responsive(
) or not _can_bind_at_port(
socket.AF_INET6, "::1", port
) # Test IPv6 localhost (::1)

Expand All @@ -99,8 +104,15 @@ def change_port(port: int, _type: str) -> int:
Returns:
The new port.

Raises:
Exit: If the port is invalid or if the new port is occupied.
"""
new_port = port + 1
if new_port < 0 or new_port > 65535:
console.error(
f"The {_type} port: {port} is invalid. It must be between 0 and 65535."
)
raise click.exceptions.Exit(1)
if is_process_on_port(new_port):
return change_port(new_port, _type)
console.info(
Expand Down
161 changes: 161 additions & 0 deletions tests/units/utils/test_processes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Test process utilities."""

import socket
import threading
import time
from contextlib import closing
from unittest import mock

import pytest

from reflex.utils.processes import is_process_on_port


def test_is_process_on_port_free_port():
"""Test is_process_on_port returns False when port is free."""
# Find a free port
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(("127.0.0.1", 0))
free_port = sock.getsockname()[1]

# Port should be free after socket is closed
assert not is_process_on_port(free_port)


def test_is_process_on_port_occupied_port():
"""Test is_process_on_port returns True when port is occupied."""
# Create a server socket to occupy a port
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(("127.0.0.1", 0))
server_socket.listen(1)

occupied_port = server_socket.getsockname()[1]

try:
# Port should be occupied
assert is_process_on_port(occupied_port)
finally:
server_socket.close()


def test_is_process_on_port_ipv6():
"""Test is_process_on_port works with IPv6."""
# Test with IPv6 socket
try:
server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(("::1", 0))
server_socket.listen(1)

occupied_port = server_socket.getsockname()[1]

try:
# Port should be occupied on IPv6
assert is_process_on_port(occupied_port)
finally:
server_socket.close()
except OSError:
# IPv6 might not be available on some systems
pytest.skip("IPv6 not available on this system")


def test_is_process_on_port_both_protocols():
"""Test is_process_on_port detects occupation on either IPv4 or IPv6."""
# Create IPv4 server
ipv4_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ipv4_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ipv4_socket.bind(("127.0.0.1", 0))
ipv4_socket.listen(1)

port = ipv4_socket.getsockname()[1]

try:
# Should detect IPv4 occupation
assert is_process_on_port(port)
finally:
ipv4_socket.close()


@pytest.mark.parametrize("port", [0, 1, 80, 443, 8000, 3000, 65535])
def test_is_process_on_port_various_ports(port):
"""Test is_process_on_port with various port numbers.

Args:
port: The port number to test.
"""
# This test just ensures the function doesn't crash with different port numbers
# The actual result depends on what's running on the system
result = is_process_on_port(port)
assert isinstance(result, bool)


def test_is_process_on_port_mock_socket_error():
"""Test is_process_on_port handles socket errors gracefully."""
with mock.patch("socket.socket") as mock_socket:
mock_socket_instance = mock.MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.__enter__.return_value = mock_socket_instance
mock_socket_instance.bind.side_effect = OSError("Mock socket error")

# Should return True when socket operations fail
result = is_process_on_port(8080)
assert result is True


def test_is_process_on_port_permission_error():
"""Test is_process_on_port handles permission errors."""
with mock.patch("socket.socket") as mock_socket:
mock_socket_instance = mock.MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.__enter__.return_value = mock_socket_instance
mock_socket_instance.bind.side_effect = PermissionError("Permission denied")

# Should return True when permission is denied (can't bind = port is "occupied")
result = is_process_on_port(80)
assert result is True


@pytest.mark.parametrize("should_listen", [True, False])
def test_is_process_on_port_concurrent_access(should_listen):
"""Test is_process_on_port works correctly with concurrent access.

Args:
should_listen: Whether the server socket should call listen() or just bind().
"""

def create_server_and_test(port_holder, listen):
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind(("127.0.0.1", 0))

if listen:
server.listen(1)

port = server.getsockname()[1]
port_holder[0] = port

# Small delay to ensure the test runs while server is active
time.sleep(0.1)
server.close()

port_holder = [None]
thread = threading.Thread(
target=create_server_and_test, args=(port_holder, should_listen)
)
thread.start()

# Wait a bit for the server to start
time.sleep(0.05)

if port_holder[0] is not None:
# Port should be occupied while server is running (both bound-only and listening)
assert is_process_on_port(port_holder[0])

thread.join()

# After thread ends and server closes, port should be free
if port_holder[0] is not None:
# Give it a moment for the socket to be fully released
time.sleep(0.1)
assert not is_process_on_port(port_holder[0])
Loading