Skip to content

Commit 1a2ae77

Browse files
committed
feat: multiple replicas for renderer and generator
1 parent 973bbb0 commit 1a2ae77

4 files changed

Lines changed: 112 additions & 10 deletions

File tree

commit.py

Lines changed: 93 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -638,7 +638,39 @@ def _parse_commitments(commitments: dict, round_number: int, schedule: Schedule,
638638
help="HuggingFace token to pass as HF_TOKEN environment variable",
639639
)
640640
@click.option("--name", "container_name", default=None, help="Custom container name (default: generator)")
641-
def start_generator_cmd(image_url: str, targon_api_key: str, hf_token: str | None, container_name: str | None) -> None:
641+
@click.option(
642+
"--container-concurrency",
643+
"container_concurrency",
644+
type=int,
645+
default=1,
646+
show_default=True,
647+
help="Maximum concurrent requests per generator replica.",
648+
)
649+
@click.option(
650+
"--min-replicas",
651+
"min_replicas",
652+
type=int,
653+
default=1,
654+
show_default=True,
655+
help="Minimum number of generator replicas.",
656+
)
657+
@click.option(
658+
"--max-replicas",
659+
"max_replicas",
660+
type=int,
661+
default=2,
662+
show_default=True,
663+
help="Maximum number of generator replicas.",
664+
)
665+
def start_generator_cmd(
666+
image_url: str,
667+
targon_api_key: str,
668+
hf_token: str | None,
669+
container_name: str | None,
670+
container_concurrency: int,
671+
min_replicas: int,
672+
max_replicas: int,
673+
) -> None:
642674
"""Start the generator container."""
643675
click.echo(f"Starting generator: {image_url}", err=True)
644676

@@ -663,6 +695,9 @@ def start_generator_cmd(image_url: str, targon_api_key: str, hf_token: str | Non
663695
health_check_path=_GENERATOR_HEALTH_CHECK_PATH,
664696
echo=lambda msg: click.echo(msg, err=True),
665697
env=env,
698+
container_concurrency=container_concurrency,
699+
min_replicas=min_replicas,
700+
max_replicas=max_replicas,
666701
)
667702
)
668703
click.echo(json.dumps({"success": True, "container_url": container_url}))
@@ -678,7 +713,36 @@ def start_generator_cmd(image_url: str, targon_api_key: str, hf_token: str | Non
678713

679714
@cli.command("start-renderer")
680715
@click.option("--targon-api-key", required=True, help="Targon API key")
681-
def start_renderer_cmd(targon_api_key: str) -> None:
716+
@click.option(
717+
"--container-concurrency",
718+
"container_concurrency",
719+
type=int,
720+
default=1,
721+
show_default=True,
722+
help="Maximum concurrent requests per renderer replica.",
723+
)
724+
@click.option(
725+
"--min-replicas",
726+
"min_replicas",
727+
type=int,
728+
default=1,
729+
show_default=True,
730+
help="Minimum number of renderer replicas.",
731+
)
732+
@click.option(
733+
"--max-replicas",
734+
"max_replicas",
735+
type=int,
736+
default=2,
737+
show_default=True,
738+
help="Maximum number of renderer replicas.",
739+
)
740+
def start_renderer_cmd(
741+
targon_api_key: str,
742+
container_concurrency: int,
743+
min_replicas: int,
744+
max_replicas: int,
745+
) -> None:
682746
"""Start the renderer container."""
683747
click.echo(f"Starting renderer: {_RENDER_IMAGE_URL}", err=True)
684748

@@ -692,6 +756,9 @@ def start_renderer_cmd(targon_api_key: str) -> None:
692756
port=_RENDER_PORT,
693757
health_check_path=_RENDER_HEALTH_CHECK_PATH,
694758
echo=lambda msg: click.echo(msg, err=True),
759+
container_concurrency=container_concurrency,
760+
min_replicas=min_replicas,
761+
max_replicas=max_replicas,
695762
)
696763
)
697764
click.echo(json.dumps({"success": True, "container_url": container_url}))
@@ -709,14 +776,22 @@ def start_renderer_cmd(targon_api_key: str) -> None:
709776
@click.option("--data-dir", required=True, help="Path to the directory containing the .ply files to render")
710777
@click.option("--endpoint", required=True, help="Renderer endpoint URL.")
711778
@click.option("--output-dir", default="results", help="Path to the directory where the rendered images will be saved.")
712-
def render_cmd(data_dir: str, endpoint: str, output_dir: str) -> None:
779+
@click.option(
780+
"--concurrency",
781+
type=int,
782+
default=1,
783+
show_default=True,
784+
help="Maximum number of files rendered concurrently.",
785+
)
786+
def render_cmd(data_dir: str, endpoint: str, output_dir: str, concurrency: int) -> None:
713787
"""Render the .ply files using the renderer endpoint."""
714788
click.echo(f"Rendering {data_dir} with endpoint {endpoint}", err=True)
715789
try:
716790
renderer = Renderer(
717791
data_dir=data_dir,
718792
endpoint=endpoint,
719793
output_dir=output_dir,
794+
concurrency=concurrency,
720795
)
721796
asyncio.run(renderer.render())
722797
click.echo(json.dumps({"success": True, "output_dir": output_dir}))
@@ -833,11 +908,19 @@ async def _stop() -> None:
833908
@click.option("--endpoint", required=True, help="Generator endpoint URL.")
834909
@click.option("--seed", required=True, help="Seed for generation.")
835910
@click.option("--output-folder", default="results", help="Folder path where generated .ply files will be saved.")
911+
@click.option(
912+
"--concurrency",
913+
type=int,
914+
default=8,
915+
show_default=True,
916+
help="Maximum number of prompts / HTTP requests processed concurrently.",
917+
)
836918
def generate_cmd(
837919
prompts_file: str,
838920
endpoint: str,
839921
seed: str,
840922
output_folder: str,
923+
concurrency: int,
841924
) -> None:
842925
"""Generate models using the generator endpoint."""
843926
# Read prompts from prompt file
@@ -864,6 +947,7 @@ def generate_cmd(
864947
seed=int(seed),
865948
output_folder=Path(output_folder),
866949
echo=lambda msg: click.echo(msg, err=True),
950+
concurrency=concurrency,
867951
)
868952

869953
try:
@@ -888,6 +972,9 @@ async def _create_container(
888972
echo: Callable[[str], None],
889973
args: list[str] | None = None,
890974
env: dict[str, str] | None = None,
975+
container_concurrency: int = 1,
976+
min_replicas: int = 1,
977+
max_replicas: int = 2,
891978
) -> str:
892979
"""
893980
Create and deploy a container on Targon.
@@ -914,7 +1001,9 @@ async def _create_container(
9141001
image=image_url,
9151002
resource_name=resource_name,
9161003
port=port,
917-
container_concurrency=1,
1004+
container_concurrency=container_concurrency,
1005+
min_replicas=min_replicas,
1006+
max_replicas=max_replicas,
9181007
args=args,
9191008
env=env,
9201009
)

generator.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ def __init__(
1414
seed: int,
1515
output_folder: Path,
1616
echo: Callable[[str], None] | None = None,
17+
concurrency: int = 8,
1718
) -> None:
1819
"""
1920
Initialize the Generator.
@@ -23,11 +24,13 @@ def __init__(
2324
seed: Seed value for generation (ensures reproducibility)
2425
output_folder: Path to folder where .ply files will be saved
2526
echo: Optional callback function for logging messages
27+
concurrency: Max concurrent prompts / HTTP requests
2628
"""
2729
self.endpoint = endpoint
2830
self.seed = seed
2931
self.output_folder = Path(output_folder)
3032
self.echo = echo or (lambda msg: None)
33+
self.concurrency = concurrency
3134

3235
# Create output folder if it doesn't exist
3336
self.output_folder.mkdir(parents=True, exist_ok=True)
@@ -46,8 +49,8 @@ async def generate_all(self, prompts: list[str]) -> None:
4649
tasks = []
4750
try:
4851
self.echo(f"Processing {len(prompts)} prompts...")
49-
request_sem = asyncio.Semaphore(1) # Using semaphores to limit request to one at a time.
50-
process_sem = asyncio.Semaphore(8) # Limiting request to control traffic
52+
request_sem = asyncio.Semaphore(self.concurrency)
53+
process_sem = asyncio.Semaphore(self.concurrency)
5154
tasks = [
5255
asyncio.create_task(
5356
self._process_prompt(

renderer.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,18 +8,26 @@
88

99

1010
class Renderer:
11-
def __init__(self, *, endpoint: str, data_dir: str, output_dir: str) -> None:
11+
def __init__(
12+
self,
13+
*,
14+
endpoint: str,
15+
data_dir: str,
16+
output_dir: str,
17+
concurrency: int = 1,
18+
) -> None:
1219
self._endpoint = endpoint
1320
self._data_dir = Path(data_dir)
1421
self._output_dir = Path(output_dir)
1522
self._output_dir.mkdir(parents=True, exist_ok=True)
23+
self._concurrency = concurrency
1624

1725
async def render(self) -> None:
1826
"""Render the .ply and .glb files using the renderer endpoint."""
1927
click.echo(f"Rendering {self._data_dir} with endpoint {self._endpoint}", err=True)
2028
tasks: list[asyncio.Task] = []
2129
try:
22-
process_sem = asyncio.Semaphore(1)
30+
process_sem = asyncio.Semaphore(self._concurrency)
2331
# Collect both .ply and .glb files
2432
ply_files = list(self._data_dir.glob("*.ply"))
2533
glb_files = list(self._data_dir.glob("*.glb"))

targon_client.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@ class ContainerDeployConfig(BaseModel):
2424

2525
image: str
2626
container_concurrency: int
27+
min_replicas: int = 1
28+
max_replicas: int = 2
2729
resource_name: str = "h200-small"
2830
port: int = 10006
2931
args: list[str] | None = None
@@ -96,8 +98,8 @@ async def deploy_container(self, name: str, config: ContainerDeployConfig) -> No
9698
visibility="external",
9799
),
98100
scaling=AutoScalingConfig(
99-
min_replicas=1,
100-
max_replicas=1,
101+
min_replicas=config.min_replicas,
102+
max_replicas=config.max_replicas,
101103
container_concurrency=config.container_concurrency,
102104
target_concurrency=config.container_concurrency,
103105
),

0 commit comments

Comments
 (0)