|
2 | 2 | import shutil |
3 | 3 | import subprocess |
4 | 4 | import tempfile |
| 5 | +import click |
5 | 6 | import requests |
6 | 7 | import yaml |
7 | 8 |
|
8 | 9 | class NuGetSyncer: |
9 | | - def __init__(self): |
| 10 | + def __init__(self, packages_file, package_filter=None): |
10 | 11 | self.NUGET_FEED_URL = "https://pkgs.dev.azure.com/Keyfactor/_packaging/KeyfactorPackages/nuget/v3/index.json" |
11 | 12 | self.GITHUB_NUGET_URL = "https://nuget.pkg.github.com/keyfactor/index.json" |
12 | 13 | self.GITHUB_TOKEN = os.getenv("GH_NUGET_TOKEN", os.getenv("GITHUB_TOKEN")) |
13 | 14 | self.AZ_DEVOPS_PAT = os.getenv("AZ_DEVOPS_PAT") |
14 | 15 | self.TMP_DIR = "nupkgs" |
15 | | - self.PACKAGES_YML = "packages.yml" |
16 | 16 | self.GITHUB_NUGET_BASE = "https://nuget.pkg.github.com/keyfactor" |
17 | | - self.allowed_packages = self.load_allowed_packages() |
| 17 | + self.package_filter = package_filter |
| 18 | + self.allowed_packages = self.load_allowed_packages(packages_file) |
18 | 19 | self._github_versions_cache = {} |
19 | 20 | os.makedirs(self.TMP_DIR, exist_ok=True) |
20 | 21 |
|
21 | | - def load_allowed_packages(self): |
| 22 | + def load_allowed_packages(self, packages_file): |
22 | 23 | try: |
23 | | - with open(self.PACKAGES_YML, 'r') as file: |
24 | | - yaml_data = yaml.safe_load(file) |
25 | | - return yaml_data.get('packages', []) |
| 24 | + with open(packages_file, 'r') as f: |
| 25 | + packages = yaml.safe_load(f).get('packages', []) or [] |
26 | 26 | except Exception as e: |
27 | | - print(f"Error loading packages.yml: {e}") |
28 | | - return set() |
29 | | - |
30 | | - def get_all_packages_and_versions(self): |
31 | | - # This function should be implemented to get all allowed packages and their versions |
32 | | - # For now, let's assume you have a list of versions for each package in packages.yml |
33 | | - # You can extend this to read from example_versions.json or another source |
34 | | - # Example: { 'PackageA': ['1.0.0', '2.0.0'], ... } |
35 | | - # For demonstration, we'll just return the allowed packages with no versions |
36 | | - if isinstance(self.allowed_packages, str): |
37 | | - return f"{self.allowed_packages}".split(",") |
38 | | - elif isinstance(self.allowed_packages, list) or isinstance(self.allowed_packages, set): |
39 | | - return self.allowed_packages |
40 | | - return {pkg: [] for pkg in self.allowed_packages} |
| 27 | + click.echo(f"Error loading {packages_file}: {e}", err=True) |
| 28 | + return [] |
| 29 | + if self.package_filter: |
| 30 | + packages = [p for p in packages if p.get('name', '').lower() == self.package_filter.lower()] |
| 31 | + if not packages: |
| 32 | + raise click.BadParameter(f"Package '{self.package_filter}' not found in {packages_file}") |
| 33 | + return packages |
41 | 34 |
|
42 | 35 | def get_github_published_versions(self, name): |
43 | 36 | """Fetch the list of versions already published to GitHub Packages for a given package.""" |
@@ -176,14 +169,13 @@ def upload_all_packages_to_github(self): |
176 | 169 |
|
177 | 170 | def sync_packages(self): |
178 | 171 | if not self.allowed_packages: |
179 | | - print("No packages specified in packages.yml. Nothing to sync.") |
| 172 | + click.echo("No packages specified. Nothing to sync.") |
180 | 173 | return |
181 | | - print(f"Will sync the following packages: {self.allowed_packages}") |
182 | | - packages_and_versions = self.get_all_packages_and_versions() |
| 174 | + click.echo(f"Will sync the following packages: {[p.get('name', p) for p in self.allowed_packages]}") |
183 | 175 | skipped = 0 |
184 | 176 | successful = 0 |
185 | 177 | failed = 0 |
186 | | - for pkg in packages_and_versions: |
| 178 | + for pkg in self.allowed_packages: |
187 | 179 | pkg_name = pkg.get('name', pkg) |
188 | 180 | versions = pkg.get('versions', []) |
189 | 181 | published = self.get_github_published_versions(pkg_name) |
@@ -216,11 +208,143 @@ def sync_packages(self): |
216 | 208 | print(f" Skipped: {skipped}") |
217 | 209 | print(f" Failed: {failed}") |
218 | 210 |
|
219 | | -if __name__ == "__main__": |
220 | | - syncer = NuGetSyncer() |
| 211 | +AZDO_FEED_BASE = "https://pkgs.dev.azure.com/Keyfactor/_packaging/KeyfactorPackages/nuget/v3/flat2" |
| 212 | + |
| 213 | + |
| 214 | +def _validate_versions(name, versions, az_pat): |
| 215 | + """Check that each version exists in the Azure DevOps feed.""" |
| 216 | + resp = requests.get( |
| 217 | + f"{AZDO_FEED_BASE}/{name.lower()}/index.json", |
| 218 | + auth=("any", az_pat), |
| 219 | + timeout=15, |
| 220 | + ) |
| 221 | + if resp.status_code != 200: |
| 222 | + raise click.ClickException(f"Package '{name}' not found in Azure DevOps feed.") |
| 223 | + available = set(resp.json().get("versions", [])) |
| 224 | + missing = [v for v in versions if v not in available] |
| 225 | + if missing: |
| 226 | + raise click.ClickException( |
| 227 | + f"Version(s) not found in Azure DevOps feed: {', '.join(missing)}\n" |
| 228 | + f"Available: {', '.join(sorted(available))}" |
| 229 | + ) |
| 230 | + |
| 231 | + |
| 232 | +def _write_versions_to_file(packages_file, name, versions): |
| 233 | + """ |
| 234 | + Insert versions into packages_file using line-based editing to preserve |
| 235 | + all comments and formatting. Returns (added, skipped) version lists. |
| 236 | + """ |
| 237 | + with open(packages_file, 'r') as f: |
| 238 | + lines = f.readlines() |
| 239 | + |
| 240 | + # Parse current state to know which versions already exist |
| 241 | + with open(packages_file, 'r') as f: |
| 242 | + data = yaml.safe_load(f) |
| 243 | + packages = data.get('packages') or [] |
| 244 | + existing = next((p for p in packages if p.get('name', '').lower() == name.lower()), None) |
| 245 | + |
| 246 | + already_present = {str(v) for v in existing.get('versions', [])} if existing else set() |
| 247 | + to_add = [v for v in versions if v not in already_present] |
| 248 | + skipped = [v for v in versions if v in already_present] |
| 249 | + |
| 250 | + if not to_add: |
| 251 | + return [], skipped |
| 252 | + |
| 253 | + if existing: |
| 254 | + # Find the last version line for this package and insert after it. |
| 255 | + # Locate the `- name: <name>` line first. |
| 256 | + pkg_line = next( |
| 257 | + (i for i, l in enumerate(lines) if l.strip().lstrip('- ').startswith(f'name: {name}')), |
| 258 | + None, |
| 259 | + ) |
| 260 | + if pkg_line is None: |
| 261 | + raise click.ClickException(f"Could not locate '{name}' in {packages_file}.") |
| 262 | + |
| 263 | + # Walk forward to find the last `- <version>` line inside this package block. |
| 264 | + last_ver_line = None |
| 265 | + ver_indent = None |
| 266 | + in_versions = False |
| 267 | + for i in range(pkg_line + 1, len(lines)): |
| 268 | + stripped = lines[i].strip() |
| 269 | + if not stripped or stripped.startswith('#'): |
| 270 | + continue |
| 271 | + if stripped == 'versions:': |
| 272 | + in_versions = True |
| 273 | + continue |
| 274 | + if in_versions: |
| 275 | + if stripped.startswith('- ') and not stripped.startswith('- name:'): |
| 276 | + last_ver_line = i |
| 277 | + ver_indent = len(lines[i]) - len(lines[i].lstrip()) |
| 278 | + else: |
| 279 | + break # hit next key or next package |
| 280 | + elif stripped.startswith('- name:'): |
| 281 | + break # hit next package without finding versions |
| 282 | + |
| 283 | + if last_ver_line is None: |
| 284 | + raise click.ClickException(f"Could not find versions block for '{name}'.") |
| 285 | + |
| 286 | + for v in reversed(to_add): |
| 287 | + lines.insert(last_ver_line + 1, ' ' * ver_indent + f'- {v}\n') |
| 288 | + else: |
| 289 | + # Append new package block at the end of the file. |
| 290 | + if lines and not lines[-1].endswith('\n'): |
| 291 | + lines.append('\n') |
| 292 | + lines.append(f' - name: {name}\n') |
| 293 | + lines.append(f' versions:\n') |
| 294 | + for v in to_add: |
| 295 | + lines.append(f' - {v}\n') |
221 | 296 |
|
222 | | - # Option 1: Download and upload packages from packages.yml |
| 297 | + with open(packages_file, 'w') as f: |
| 298 | + f.writelines(lines) |
| 299 | + |
| 300 | + return to_add, skipped |
| 301 | + |
| 302 | + |
| 303 | +@click.group() |
| 304 | +def cli(): |
| 305 | + """Manage NuGet package sync between Azure DevOps and GitHub Packages.""" |
| 306 | + pass |
| 307 | + |
| 308 | + |
| 309 | +@cli.command() |
| 310 | +@click.argument("packages_file", type=click.Path(exists=True, dir_okay=False)) |
| 311 | +@click.option("--package", default=None, help="Sync only this package name (must exist in the packages file).") |
| 312 | +def sync(packages_file, package): |
| 313 | + """Sync packages from Azure DevOps to GitHub Packages.""" |
| 314 | + syncer = NuGetSyncer(packages_file, package_filter=package) |
223 | 315 | syncer.sync_packages() |
224 | 316 |
|
225 | | - # Option 2: Upload all existing packages in nupkgs directory |
226 | | - # syncer.upload_all_packages_to_github() |
| 317 | + |
| 318 | +@cli.command() |
| 319 | +@click.argument("packages_file", type=click.Path(dir_okay=False)) |
| 320 | +@click.argument("name") |
| 321 | +@click.argument("versions", nargs=-1, required=True) |
| 322 | +@click.option("--skip-validate", is_flag=True, default=False, |
| 323 | + help="Skip Azure DevOps feed validation.") |
| 324 | +def register(packages_file, name, versions, skip_validate): |
| 325 | + """Add NAME with one or more VERSIONS to PACKAGES_FILE. |
| 326 | +
|
| 327 | + Validates each version exists in the Azure DevOps feed before writing. |
| 328 | + Requires AZ_DEVOPS_PAT env var unless --skip-validate is set. |
| 329 | + """ |
| 330 | + if not skip_validate: |
| 331 | + az_pat = os.getenv("AZ_DEVOPS_PAT") |
| 332 | + if not az_pat: |
| 333 | + raise click.ClickException( |
| 334 | + "AZ_DEVOPS_PAT env var is required for validation. Use --skip-validate to bypass." |
| 335 | + ) |
| 336 | + click.echo(f"Validating {name} against Azure DevOps feed...") |
| 337 | + _validate_versions(name, versions, az_pat) |
| 338 | + |
| 339 | + added, skipped = _write_versions_to_file(packages_file, name, versions) |
| 340 | + |
| 341 | + if added: |
| 342 | + click.echo(f"Registered {name}: {', '.join(added)}") |
| 343 | + if skipped: |
| 344 | + click.echo(f"Already in {packages_file}, skipped: {', '.join(skipped)}") |
| 345 | + if not added and not skipped: |
| 346 | + click.echo("Nothing to register.") |
| 347 | + |
| 348 | + |
| 349 | +if __name__ == "__main__": |
| 350 | + cli() |
0 commit comments