WIP: feat: Add router in rollout controller for simpler proxy server usage #1013
zhanghaotong wants to merge 2 commits into
Signed-off-by: 皓聪 <zhanghaotong.zht@antgroup.com>
Signed-off-by: 皓聪 <zhanghaotong.zht@antgroup.com>
Summary of Changes (Gemini Code Assist)
This pull request refactors the worker selection and proxy address management within the …
Code Review
This pull request introduces a ProxyRouter using the Strategy pattern to handle worker selection in the RolloutController, which is a great improvement for extensibility. The implementation is solid, and the addition of unit tests and documentation is appreciated.
I've found a critical issue in the fallback logic of _choose_worker that would break the round-robin behavior. I've also pointed out a significant maintainability issue in the new test file tests/test_proxy_router.py due to code duplication. Addressing these points will make this feature more robust and maintainable.
    if self._proxy_router is not None:
        worker, rank, _ = self._proxy_router.route()
        return worker, rank

    # Fallback to legacy round-robin if router not initialized
    fallback_selector = RoundRobinSelector()
    rank = fallback_selector.select(self.workers)
    return self.workers[rank], rank
The fallback logic for _choose_worker is incorrect. It creates a new RoundRobinSelector on every call, which will always select the first worker (rank 0) because its internal index is reset to 0 upon instantiation. This breaks the expected round-robin behavior and will cause all fallback requests to be routed to a single worker, potentially overloading it.
Since _proxy_router should be initialized before this method is called, it's safer to raise an error if it's not available, similar to how route_worker_with_proxy handles it.
Suggested change:

    -if self._proxy_router is not None:
    -    worker, rank, _ = self._proxy_router.route()
    -    return worker, rank
    -# Fallback to legacy round-robin if router not initialized
    -fallback_selector = RoundRobinSelector()
    -rank = fallback_selector.select(self.workers)
    -return self.workers[rank], rank
    +if self._proxy_router is None:
    +    raise RuntimeError("ProxyRouter not initialized. Call initialize() first.")
    +worker, rank, _ = self._proxy_router.route()
    +return worker, rank
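The failure mode described in this comment can be reproduced in isolation. The sketch below uses a stand-in `RoundRobinSelector` that mirrors the one quoted in this PR (with plain strings in place of worker objects, which is an assumption made only to keep the example self-contained), and compares a selector created on every call with one that persists across calls.

```python
from __future__ import annotations


class RoundRobinSelector:
    """Stand-in mirroring the selector quoted in this PR."""

    def __init__(self) -> None:
        self._current_idx = 0

    def select(self, workers: list[str]) -> int:
        if not workers:
            raise RuntimeError("No workers available to choose from.")
        rank = self._current_idx
        self._current_idx = (self._current_idx + 1) % len(workers)
        return rank


# Placeholder worker names; real code would hold worker objects.
workers = ["worker-0", "worker-1", "worker-2"]

# Fallback as written: a fresh selector per call, so the index restarts at 0.
fresh_per_call = [RoundRobinSelector().select(workers) for _ in range(4)]

# A selector that outlives the call keeps its index and cycles through ranks.
shared = RoundRobinSelector()
shared_selector = [shared.select(workers) for _ in range(4)]

print(fresh_per_call)   # [0, 0, 0, 0]
print(shared_selector)  # [0, 1, 2, 0]
```

Raising an error, as suggested above, avoids the problem entirely; keeping a single long-lived fallback selector would be the alternative if a fallback path were still wanted.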
    @dataclass
    class MockWorker:
        """Mock Worker for testing."""

        id: str
        ip: str = "127.0.0.1"


    class RoutingStrategy(str, Enum):
        """Enumeration of available worker routing strategies."""

        ROUND_ROBIN = "round-robin"
        RANDOM = "random"


    class WorkerSelector(ABC):
        """Abstract base class for worker selection strategies."""

        @abstractmethod
        def select(self, workers: list[MockWorker]) -> int:
            """Select a worker and return its rank."""
            pass

        @abstractmethod
        def reset(self) -> None:
            """Reset the selector state."""
            pass


    class RoundRobinSelector(WorkerSelector):
        """Round-robin worker selection strategy."""

        def __init__(self):
            self._current_idx = 0

        def select(self, workers: list[MockWorker]) -> int:
            """Select the next worker in round-robin order."""
            if not workers:
                raise RuntimeError("No workers available to choose from.")

            rank = self._current_idx
            self._current_idx = (self._current_idx + 1) % len(workers)
            return rank

        def reset(self) -> None:
            """Reset the round-robin index to 0."""
            self._current_idx = 0


    class RandomSelector(WorkerSelector):
        """Random worker selection strategy."""

        def select(self, workers: list[MockWorker]) -> int:
            """Randomly select a worker."""
            if not workers:
                raise RuntimeError("No workers available to choose from.")

            return random.randint(0, len(workers) - 1)

        def reset(self) -> None:
            """No-op for stateless random strategy."""
            pass


    class ProxyRouter:
        """Router for choosing workers and managing proxy addresses using Strategy Pattern."""

        # Strategy factory: maps RoutingStrategy enum to WorkerSelector classes
        _SELECTOR_FACTORY: dict[RoutingStrategy, type[WorkerSelector]] = {
            RoutingStrategy.ROUND_ROBIN: RoundRobinSelector,
            RoutingStrategy.RANDOM: RandomSelector,
        }

        def __init__(
            self,
            workers: list[MockWorker],
            proxy_addrs: list[str] | None = None,
            routing_strategy: RoutingStrategy | str = RoutingStrategy.ROUND_ROBIN,
        ):
            """Initialize the ProxyRouter."""
            self.workers = workers
            self.proxy_addrs = proxy_addrs or []
            self._proxy_enabled = bool(proxy_addrs)

            # Convert string to enum if necessary
            if isinstance(routing_strategy, str):
                try:
                    self.strategy = RoutingStrategy(routing_strategy)
                except ValueError:
                    valid_strategies = ", ".join([s.value for s in RoutingStrategy])
                    raise ValueError(
                        f"Invalid routing_strategy: {routing_strategy}. "
                        f"Must be one of: {valid_strategies}"
                    )
            else:
                self.strategy = routing_strategy

            # Create the appropriate selector using the factory
            selector_class = self._SELECTOR_FACTORY.get(self.strategy)
            if selector_class is None:
                raise ValueError(
                    f"No selector implementation found for strategy: {self.strategy}"
                )
            self.selector = selector_class()

        def route(self) -> tuple[MockWorker, int, str | None]:
            """Choose a worker and get its proxy address."""
            # Delegate selection to the strategy
            rank = self.selector.select(self.workers)
            worker = self.workers[rank]

            # Get proxy address if available
            proxy_addr = (
                self.proxy_addrs[rank]
                if self._proxy_enabled and rank < len(self.proxy_addrs)
                else None
            )

            return worker, rank, proxy_addr

        def get_proxy_addr(self, rank: int) -> str | None:
            """Get the proxy server address for a specific worker rank."""
            if not self._proxy_enabled or rank >= len(self.proxy_addrs):
                return None
            return self.proxy_addrs[rank]

        def update_proxy_addrs(self, proxy_addrs: list[str]) -> None:
            """Update the proxy addresses."""
            self.proxy_addrs = proxy_addrs
            self._proxy_enabled = bool(proxy_addrs)

        def reset(self) -> None:
            """Reset the routing strategy's internal state."""
            self.selector.reset()
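The rank-to-proxy pairing used by `route()` and `get_proxy_addr()` in the excerpt above can be illustrated on its own. This is a minimal sketch, not code from the PR: `resolve_proxy_addr` is a hypothetical standalone helper and the addresses are made-up placeholders. It shows that a rank beyond the end of the proxy list, or an empty or missing list, resolves to `None`.

```python
from __future__ import annotations


def resolve_proxy_addr(rank: int, proxy_addrs: list[str] | None) -> str | None:
    """Hypothetical helper mirroring the lookup in the excerpt's get_proxy_addr."""
    addrs = proxy_addrs or []
    # No proxies configured, or this rank has no matching entry.
    if not addrs or rank >= len(addrs):
        return None
    return addrs[rank]


# Placeholder addresses: two proxies for three worker ranks.
proxy_addrs = ["10.0.0.1:8000", "10.0.0.2:8000"]
resolved = [resolve_proxy_addr(rank, proxy_addrs) for rank in range(3)]

print(resolved)                     # ['10.0.0.1:8000', '10.0.0.2:8000', None]
print(resolve_proxy_addr(0, None))  # None
```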
The test file duplicates several classes (RoutingStrategy, WorkerSelector, RoundRobinSelector, RandomSelector, ProxyRouter) from areal.infra.controller.rollout_controller, and also defines a MockWorker. This is a significant maintainability issue, as any changes to the original classes will require manual updates in the test file, and they can easily get out of sync.
Please remove the duplicated class definitions (lines 17-143) and the MockWorker class (lines 9-14). Instead, import the necessary components from the application code.
    # At the top of tests/test_proxy_router.py
    # Add these imports
    from areal.api import Worker
    from areal.infra.controller.rollout_controller import (
        ProxyRouter,
        RandomSelector,
        RoundRobinSelector,
        RoutingStrategy,
    )
    # In your test functions, replace MockWorker with Worker, for example:
    # workers = [Worker(id=f"worker-{i}", ip="127.0.0.1") for i in range(3)]

This makes the tests more robust and easier to maintain.
garrett4wade left a comment
Thanks for the contribution. While this PR creates a new abstraction for the router, it does not introduce useful features such as workload-aware scheduling, nor does it benchmark the performance. I think we should defer the API change until we see a clear inference throughput improvement under a smarter scheduling strategy.
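For illustration only, a workload-aware strategy could plug into the same `select()`/`reset()` interface. The sketch below is an assumption and not part of this PR: `LeastLoadedSelector` and its `release()` method are hypothetical names, the in-flight counter is a simplistic load signal, and plain strings stand in for worker objects.

```python
from __future__ import annotations


class LeastLoadedSelector:
    """Hypothetical selector: picks the worker with the fewest in-flight requests."""

    def __init__(self) -> None:
        # Maps worker rank to the number of requests currently assigned to it.
        self._in_flight: dict[int, int] = {}

    def select(self, workers: list[str]) -> int:
        if not workers:
            raise RuntimeError("No workers available to choose from.")
        # Ties resolve to the lowest rank because min() returns the first minimum.
        rank = min(range(len(workers)), key=lambda r: self._in_flight.get(r, 0))
        self._in_flight[rank] = self._in_flight.get(rank, 0) + 1
        return rank

    def release(self, rank: int) -> None:
        """Call when a request finishes so the worker's load drops again."""
        self._in_flight[rank] = max(0, self._in_flight.get(rank, 0) - 1)

    def reset(self) -> None:
        self._in_flight.clear()


workers = ["worker-0", "worker-1", "worker-2"]
selector = LeastLoadedSelector()
picks = [selector.select(workers) for _ in range(3)]  # one request per worker
selector.release(1)                                   # worker 1 finishes first
next_pick = selector.select(workers)                  # worker 1 is now least loaded

print(picks, next_pick)  # [0, 1, 2] 1
```

Unlike the round-robin and random selectors, this kind of strategy needs a completion signal (`release()` here), so the controller would have to report when a request finishes. Whether it improves throughput would still need the benchmarking asked for above.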
Sure, I'll write some more useful scheduling strategies. Let me first mark this PR as WIP.
This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days. Please add a comment or push new commits to keep it active. Thank you for your contribution!
Description
This PR introduces a ProxyRouter in RolloutController to centralize worker selection and proxy address resolution, making proxy server usage simpler and more extensible.
Key changes:
Related Issue
#907
Breaking Change Details (if applicable):
N/A