diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 829c740f..eec11e52 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -627,31 +627,41 @@ def _get_hf_api(user_agent: str | dict | None = None) -> HfApi: user_agent_str = "" if not constants.HF_HUB_DISABLE_TELEMETRY: + parts: list[str] = [] + # User-defined info if isinstance(user_agent, dict): - user_agent_str = "; ".join(f"{k}/{v}" for k, v in user_agent.items()) - if isinstance(user_agent, str): - user_agent_str = user_agent + parts.extend(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str) and user_agent: + parts.append(user_agent) # System info python = ".".join(platform.python_version_tuple()[:2]) backend = _select_backend(None).variant_str - user_agent_str += ( - f"; kernels/{__version__}; python/{python}; backend/{backend}; platform/{_platform()}; file_type/kernel" + parts.extend( + [ + f"kernels/{__version__}", + f"python/{python}", + f"backend/{backend}", + f"platform/{_platform()}", + "file_type/kernel", + ] ) if has_torch: import torch - user_agent_str += f"; torch/{torch.__version__}" + parts.append(f"torch/{torch.__version__}") if has_tvm_ffi: import tvm_ffi - user_agent_str += f"; tvm-ffi/{tvm_ffi.__version__}" + parts.append(f"tvm-ffi/{tvm_ffi.__version__}") # Add glibc version if available glibc = glibc_version() if glibc is not None: - user_agent_str += f"; glibc/{glibc}" + parts.append(f"glibc/{glibc}") + + user_agent_str = "; ".join(parts) return HfApi(library_name="kernels", library_version=__version__, user_agent=user_agent_str) diff --git a/kernels/tests/test_user_agent.py b/kernels/tests/test_user_agent.py index 11faafeb..c87e6287 100644 --- a/kernels/tests/test_user_agent.py +++ b/kernels/tests/test_user_agent.py @@ -69,3 +69,16 @@ def test_platform_format(): parts = plat.split("-") assert len(parts) == 2 assert parts[1] in ("linux", "darwin", "windows") + + +def test_user_agent_no_leading_or_empty_segment(): + # Regression: when no caller-supplied user_agent is passed, the resulting + # string must not start with a separator and must not contain empty + # segments. Empty segments downstream produce malformed User-Agent headers + # (e.g. trailing "; ") which strict HTTP clients reject. + for ua_input in (None, "", {}): + api = _get_hf_api(user_agent=ua_input) + ua = api.user_agent + assert not ua.startswith(";"), f"user_agent must not start with ';': {ua!r}" + for segment in ua.split(";"): + assert segment.strip() != "", f"empty segment found in user_agent: {ua!r}"