Skip to content

Commit aeb7006

Browse files
authored
Merge pull request #394 from hud-evals/hud-1069
Show accessible private templates in hud init
2 parents 35d6b36 + 705c596 commit aeb7006

1 file changed

Lines changed: 114 additions & 41 deletions

File tree

hud/cli/init.py

Lines changed: 114 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import questionary
1313
import typer
1414

15+
from hud.cli.utils.api import hud_headers
16+
from hud.settings import settings
1517
from hud.utils.hud_console import HUDConsole
1618

1719
# Presets mapping to public GitHub repositories under hud-evals org
@@ -22,6 +24,8 @@
2224
"blank": "hud-blank",
2325
"deep-research": "hud-deepresearch",
2426
"browser": "hud-browser",
27+
"remote-browser": "hud-remote-browser",
28+
"coding": "coding-template",
2529
"rubrics": "hud-rubrics",
2630
"verilog-coding-template": "verilog-coding-template",
2731
"data-science-template": "data-science-template",
@@ -86,34 +90,53 @@ def _replace_placeholders(target_dir: Path, env_name: str) -> list[str]:
8690
return modified_files
8791

8892

89-
def _prompt_for_preset() -> str | None:
93+
def _fetch_available_templates() -> tuple[list[dict], list[dict]]:
94+
"""Fetch available templates from the HUD API.
95+
96+
Returns (public_templates, private_templates). Falls back to empty
97+
private list if the API is unreachable or the user has no API key.
98+
"""
99+
if not settings.api_key:
100+
return [], []
101+
102+
try:
103+
with httpx.Client(timeout=10) as client:
104+
resp = client.get(
105+
f"{settings.hud_api_url}/templates/available",
106+
headers=hud_headers(),
107+
)
108+
if resp.status_code != 200:
109+
return [], []
110+
data = resp.json()
111+
return data.get("public_templates", []), data.get("private_templates", [])
112+
except Exception:
113+
return [], []
114+
115+
116+
def _prompt_for_preset() -> tuple[str, bool] | None:
90117
"""Ask the user to choose a preset when not provided.
91118
92-
Returns None if the user cancels the selection.
119+
Returns (preset_id, is_private) or None if the user cancels.
93120
"""
121+
# Fetch private templates from API
122+
_, private_templates = _fetch_available_templates()
123+
94124
try:
95-
choices = [
96-
{"name": "blank", "message": "blank"},
97-
{"name": "browser", "message": "browser"},
98-
{"name": "deep-research", "message": "deep-research"},
99-
{"name": "rubrics", "message": "rubrics"},
100-
{"name": "verilog-coding-template", "message": "verilog-coding-template"},
101-
{"name": "data-science-template", "message": "data-science-template"},
125+
choices = [questionary.Choice(title=key, value=(key, False)) for key in PRESET_MAP] + [
126+
questionary.Choice(title=t["id"], value=(t["id"], True)) for t in private_templates
102127
]
103-
display_choices = [c["message"] for c in choices]
128+
104129
selected = questionary.select(
105-
"Choose a preset", choices=display_choices, default=display_choices[0]
130+
"Choose a preset",
131+
choices=choices,
106132
).ask()
107133
if not selected:
108134
return None # User cancelled
109-
for c in choices:
110-
if c["message"] == selected:
111-
return c["name"]
112-
return "blank"
135+
return selected
113136
except KeyboardInterrupt:
114137
return None # User pressed Ctrl+C
115138
except Exception:
116-
return "blank"
139+
return ("blank", False)
117140

118141

119142
def _download_tarball_repo(
@@ -142,6 +165,32 @@ def _download_tarball_repo(
142165
tmp_file.write(chunk)
143166
tmp_path = Path(tmp_file.name)
144167

168+
_extract_tarball(tmp_path, dest_dir, files_created)
169+
170+
171+
def _download_private_template(template_id: str, dest_dir: Path, files_created: list[str]) -> None:
172+
"""Download a private template tarball from the HUD API."""
173+
url = f"{settings.hud_api_url}/templates/private/{template_id}/download"
174+
175+
with (
176+
tempfile.NamedTemporaryFile(delete=False) as tmp_file,
177+
httpx.Client(timeout=120) as client,
178+
client.stream("GET", url, headers=hud_headers()) as resp,
179+
):
180+
if resp.status_code == 403:
181+
raise RuntimeError("Access denied: your team does not have access to this template.")
182+
if resp.status_code != 200:
183+
raise RuntimeError(f"Failed to download private template (HTTP {resp.status_code})")
184+
for chunk in resp.iter_bytes():
185+
if chunk:
186+
tmp_file.write(chunk)
187+
tmp_path = Path(tmp_file.name)
188+
189+
_extract_tarball(tmp_path, dest_dir, files_created)
190+
191+
192+
def _extract_tarball(tmp_path: Path, dest_dir: Path, files_created: list[str]) -> None:
193+
"""Extract a tarball into dest_dir, stripping the top-level directory."""
145194
try:
146195
with tarfile.open(tmp_path, mode="r:gz") as tar:
147196
members = tar.getmembers()
@@ -191,15 +240,26 @@ def create_environment(
191240

192241
hud_console = HUDConsole()
193242

243+
is_private = False
244+
194245
# Choose preset
195246
if preset:
196-
preset_normalized = preset.strip().lower()
247+
preset_stripped = preset.strip()
248+
preset_normalized = preset_stripped.lower()
249+
# Check if the preset matches a private template (case-insensitive)
250+
_, private_templates = _fetch_available_templates()
251+
for t in private_templates:
252+
if t["id"].lower() == preset_normalized:
253+
# Preserve the original API ID for case-sensitive downstream use
254+
preset_normalized = t["id"]
255+
is_private = True
256+
break
197257
else:
198258
preset_result = _prompt_for_preset()
199259
if preset_result is None:
200260
# User cancelled the selection
201261
raise typer.Exit(0)
202-
preset_normalized = preset_result
262+
preset_normalized, is_private = preset_result
203263

204264
# If no name is provided, use the preset name as the environment name
205265
if name is None:
@@ -209,7 +269,7 @@ def create_environment(
209269
# Always create a new directory based on the name
210270
target_dir = Path.cwd() / name if directory == "." else Path(directory) / name
211271

212-
if preset_normalized not in PRESET_MAP:
272+
if not is_private and preset_normalized not in PRESET_MAP:
213273
available = ", ".join(sorted(PRESET_MAP.keys()))
214274
hud_console.warning(
215275
f"Unknown preset '{preset_normalized}', defaulting to 'blank' (available: {available})"
@@ -225,40 +285,53 @@ def create_environment(
225285
else:
226286
hud_console.warning(f"Overwriting existing files in {target_dir}")
227287

228-
# Download preset from GitHub
229-
repo_name = PRESET_MAP[preset_normalized]
230-
if repo_name is None:
231-
hud_console.error("Internal error: preset mapping missing repo name")
232-
raise typer.Exit(1)
233-
234288
hud_console.header(f"Initializing HUD Environment: {name} (preset: {preset_normalized})")
235-
hud_console.section_title("Downloading template from GitHub")
236-
source_url = f"https://github.com/{GITHUB_OWNER}/{repo_name}"
237-
hud_console.info("Source: " + source_url)
238-
239289
target_dir.mkdir(parents=True, exist_ok=True)
240290

241291
started = time.time()
242292
files_created_dl: list[str] = []
243-
try:
244-
_download_tarball_repo(
245-
owner=GITHUB_OWNER,
246-
repo=repo_name,
247-
ref=GITHUB_BRANCH,
248-
dest_dir=target_dir,
249-
files_created=files_created_dl,
250-
)
251-
except Exception as e:
252-
hud_console.error(f"Failed to download preset '{preset_normalized}': {e}")
253-
raise typer.Exit(1) from None
293+
294+
if is_private:
295+
hud_console.section_title("Downloading private template from HUD")
296+
try:
297+
_download_private_template(
298+
template_id=preset_normalized,
299+
dest_dir=target_dir,
300+
files_created=files_created_dl,
301+
)
302+
except Exception as e:
303+
hud_console.error(f"Failed to download private template '{preset_normalized}': {e}")
304+
raise typer.Exit(1) from None
305+
else:
306+
# Download preset from GitHub
307+
repo_name = PRESET_MAP[preset_normalized]
308+
if repo_name is None:
309+
hud_console.error("Internal error: preset mapping missing repo name")
310+
raise typer.Exit(1)
311+
312+
hud_console.section_title("Downloading template from GitHub")
313+
source_url = f"https://github.com/{GITHUB_OWNER}/{repo_name}"
314+
hud_console.info("Source: " + source_url)
315+
316+
try:
317+
_download_tarball_repo(
318+
owner=GITHUB_OWNER,
319+
repo=repo_name,
320+
ref=GITHUB_BRANCH,
321+
dest_dir=target_dir,
322+
files_created=files_created_dl,
323+
)
324+
except Exception as e:
325+
hud_console.error(f"Failed to download preset '{preset_normalized}': {e}")
326+
raise typer.Exit(1) from None
254327

255328
duration_ms = int((time.time() - started) * 1000)
256329
hud_console.success(
257330
f"Downloaded {len(files_created_dl)} files in {duration_ms} ms into {target_dir}"
258331
)
259332

260333
# Replace placeholders in template files (only for blank preset)
261-
if preset_normalized == "blank":
334+
if preset_normalized == "blank" and not is_private:
262335
hud_console.section_title("Customizing template files")
263336
modified_files = _replace_placeholders(target_dir, name)
264337
if modified_files:

0 commit comments

Comments
 (0)