diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index de23db13..76e94efc 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -4,7 +4,6 @@ on: push: branches: [ main, master, develop ] pull_request: - branches: [ main, master, develop ] workflow_dispatch: concurrency: diff --git a/.github/workflows/pr-checks.yaml b/.github/workflows/pr-checks.yaml index 4069d45f..6557835b 100644 --- a/.github/workflows/pr-checks.yaml +++ b/.github/workflows/pr-checks.yaml @@ -2,7 +2,6 @@ name: PR Checks on: pull_request: - branches: [ main, master, develop ] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/.github/workflows/syfon-backend-e2e.yaml b/.github/workflows/syfon-backend-e2e.yaml index 62bb839b..adffa94c 100644 --- a/.github/workflows/syfon-backend-e2e.yaml +++ b/.github/workflows/syfon-backend-e2e.yaml @@ -2,7 +2,6 @@ name: Syfon Backend E2E on: pull_request: - branches: [ main, master, develop ] workflow_dispatch: concurrency: @@ -34,28 +33,25 @@ jobs: with: go-version-file: git-drs/go.mod - - name: Resolve Syfon ref from go.mod + - name: Resolve Syfon commit from pinned module version id: syfon-ref working-directory: git-drs shell: bash run: | - version="$(go list -m -f '{{ .Version }}' github.com/calypr/syfon 2>/dev/null || true)" + version="$(go list -m -f '{{ .Version }}' github.com/calypr/syfon)" if [[ -z "$version" || "$version" == "" ]]; then - version="$(go list -m -f '{{ .Version }}' github.com/calypr/syfon/client)" + echo "error: could not resolve pinned syfon version" >&2 + exit 1 fi - ref="$version" - if [[ "$version" =~ -([0-9a-f]{12,})$ ]]; then - short_ref="${BASH_REMATCH[1]}" - ref="$(git ls-remote https://github.com/calypr/syfon.git "refs/heads/*" "refs/tags/*" | - awk -v short="$short_ref" 'index($1, short) == 1 && ref == "" { ref = $1 } END { print ref }')" - if [[ -z "$ref" ]]; then - echo "Could not resolve Syfon pseudo-version commit $short_ref to a full Git SHA" >&2 - exit 1 - fi + short_ref="${version##*-}" + full_ref="$(git ls-remote https://github.com/calypr/syfon.git | awk -v short="$short_ref" '$1 ~ "^" short && first == "" { first = $1 } END { if (first != "") print first; else exit 1 }')" + if [[ -z "$full_ref" ]]; then + echo "error: could not resolve full syfon commit for $short_ref" >&2 + exit 1 fi echo "version=$version" >> "$GITHUB_OUTPUT" - echo "ref=$ref" >> "$GITHUB_OUTPUT" - echo "Using Syfon module version $version; checking out calypr/syfon ref $ref" + echo "ref=$full_ref" >> "$GITHUB_OUTPUT" + echo "Using Syfon module version $version; checking out calypr/syfon ref $full_ref" - name: Check out Syfon uses: actions/checkout@v4 @@ -65,13 +61,6 @@ jobs: path: syfon fetch-depth: 1 - - name: Check out data-client - uses: actions/checkout@v4 - with: - repository: calypr/data-client - path: data-client - fetch-depth: 1 - - name: Install test prerequisites working-directory: git-drs run: | diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 82d0ec2c..38b2fc29 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,7 +4,6 @@ on: push: branches: [ main, master, develop ] pull_request: - branches: [ main, master, develop ] workflow_dispatch: concurrency: diff --git a/Makefile b/Makefile index e5ed78dc..3d0f94eb 100644 --- a/Makefile +++ b/Makefile @@ -48,19 +48,29 @@ lint: @echo "Running go vet..." @go vet ./... @echo "Running gofmt..." - @test -z "$$(gofmt -s -l . | tee /dev/stderr)" || (echo "Please run: gofmt -s -w ." && exit 1) + @files="$$(gofmt -s -l cmd/ internal/ tests/ git-drs.go)"; \ + if [ -n "$$files" ]; then \ + printf "%s\n" "$$files"; \ + echo "Please run: gofmt -s -w cmd/ internal/ tests/ git-drs.go"; \ + exit 1; \ + fi @echo "Running goimports..." - @test -z "$$(goimports -l . | tee /dev/stderr)" || (echo "Please run: goimports -w ." && exit 1) + @files="$$(goimports -l cmd/ internal/ tests/ git-drs.go)"; \ + if [ -n "$$files" ]; then \ + printf "%s\n" "$$files"; \ + echo "Please run: goimports -w cmd/ internal/ tests/ git-drs.go"; \ + exit 1; \ + fi @echo "Running misspell..." - @misspell -error . + @misspell -error cmd/ internal/ tests/ docs/ *.go *.md Makefile @echo "✅ All lint checks passed!" # Auto-fix formatting issues fmt: @echo "Formatting with gofmt..." - @gofmt -s -w . + @gofmt -s -w cmd/ internal/ tests/ git-drs.go @echo "Formatting with goimports..." - @goimports -w . + @goimports -w cmd/ internal/ tests/ git-drs.go @echo "✅ Formatting complete!" # Run all tests @@ -117,4 +127,4 @@ full: proto install tidy lint test website webdash clean: @rm -rf ./bin ./pkg ./test_tmp ./build ./buildtools -.PHONY: proto proto-lint website docker webdash build debug coverage coverage-clients coverage-html-full +.PHONY: proto proto-lint website docker webdash build debug coverage coverage-clients coverage-html-full test install diff --git a/README.md b/README.md index 9fa44367..eb3e5d91 100644 --- a/README.md +++ b/README.md @@ -3,143 +3,124 @@ --- # NOTICE -git-drs is not yet fully compliant with DRS. It currently works against Gen3 DRS server. Full GA4GH DRS support is expected once v1.6 of the specification has been published. +`git-drs` is not a pure GA4GH DRS client. It targets Syfon/Gen3-style DRS workflows and uses extensions where repo-scale behavior requires them. --- [![Tests](https://github.com/calypr/git-drs/actions/workflows/test.yaml/badge.svg)](https://github.com/calypr/git-drs/actions/workflows/test.yaml) -**Git/DRS orchestration with optional Git LFS compatibility** +**Git/DRS orchestration with Git-compatible pointer workflows** -Git DRS manages Git-facing DRS workflows: local metadata, Git hooks, filter behavior, lookup/register/push/pull orchestration, and optional Git LFS compatibility. Provider-specific transfer, signed URL behavior, and direct cloud inspection live in client code outside this repo. +`git-drs` manages: + +- remote Gen3/Syfon configuration +- local DRS metadata +- pointer-aware push/pull orchestration +- bucket-scoped object reference workflows ## Key Features -- **Unified Workflow**: Manage both code and large data files using standard Git commands -- **DRS Integration**: Built-in support for Gen3 DRS servers -- **Multi-Remote Support**: Work with development, staging, and production servers in one repository -- **Automatic Processing**: Files are processed automatically during commits and pushes -- **Flexible Tracking**: Track individual files, patterns, or entire directories +- unified Git/data workflow around DRS-backed pointers +- Gen3/Syfon integration +- multiple remotes in one repository +- explicit file tracking and hydration +- metadata-only reference support for existing bucket objects ## How It Works -Git DRS works alongside Git LFS when you want LFS-compatible pointers and storage, while still supporting DRS-centric workflows: +At a high level: -1. **Initialization**: Set up repository and DRS server configuration -2. **Automatic Commits**: Create DRS objects during pre-commit hooks -3. **Automatic Pushes**: Register files with DRS servers and upload to configured storage -4. **On-Demand Downloads**: Pull specific files or patterns as needed +1. configure a remote for one `organization/project` +2. let `remote add` bootstrap repo-local `git-drs` state if needed +3. track file patterns with `git drs track` +4. add and commit with normal Git +5. remove tracked pointers with `git drs rm` when you want repository deletion to reconcile with remote DRS state +6. run `git drs push` for managed metadata registration/upload plus Git push +7. hydrate pointer files later with `git drs pull` ## Quick Start -### Installation - ```bash -# Install Git LFS first -brew install git-lfs # macOS -git lfs install --skip-smudge - -# Install Git DRS -/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/calypr/git-drs/refs/heads/main/install.sh)" -- $GIT_DRS_VERSION - -# Install global Git filter configuration for git-drs git drs install -``` - -### Basic Usage - -```bash -# Initialize repository (one-time Git repo setup) -git drs init - -# Add DRS remote -git drs remote add gen3 production \ - --cred /path/to/credentials.json \ - --url https://calypr-public.ohsu.edu \ - --organization my-program \ - --project my-project \ - --bucket my-bucket - -# Required prerequisite (usually steward/admin setup): -# create bucket credentials, then map org/project to full storage roots before users run push/pull -git drs bucket add production \ - --bucket my-bucket \ - --region us-east-1 \ - --access-key "$AWS_ACCESS_KEY_ID" \ - --secret-key "$AWS_SECRET_ACCESS_KEY" \ - --s3-endpoint https://s3.amazonaws.com -git drs bucket add-organization production \ - --organization my-program \ - --path s3://my-bucket/my-program -git drs bucket add-project production \ - --organization my-program \ - --project my-project \ - --path s3://my-bucket/my-program/my-project - -# Track files -git lfs track "*.bam" +git drs remote add gen3 production HTAN_INT/BForePC --cred /path/to/credentials.json +git drs track "*.bam" git add .gitattributes - -# Add and commit files -git add my-file.bam -git commit -m "Add data file" -git push - -# Download files -git lfs pull -I "*.bam" +git add sample.bam +git commit -m "Add sample" +git drs push +git drs ls-files +git drs pull -I "*.bam" ``` -## Documentation - -For detailed setup and usage information: +## Current CLI Shape -- **[Getting Started](docs/getting-started.md)** - Repository setup and basic workflows -- **[Commands Reference](docs/commands.md)** - Complete command documentation -- **[Installation Guide](docs/installation.md)** - Platform-specific installation -- **[Troubleshooting](docs/troubleshooting.md)** - Common issues and solutions -- **[E2E Modes + Local Setup](docs/e2e-modes-and-local-setup.md)** - Local vs remote mode, server config, and end-to-end runbooks -- **[Cloud/Object Integration](docs/adding-s3-files.md)** - Adding files from provider URLs or configured bucket object keys -- **[Developer Guide](docs/developer-guide.md)** - Internals and development +The cleaned CLI intentionally removed legacy commands: -## Supported Servers +- removed: + - `git drs fetch` + - `git drs list` + - `git drs upload` + - `git drs download` +- `git drs pull` is hydration-only +- `git drs ls-files` is the local file inventory command +- `git drs remote add gen3` takes scope as `organization/project` -- **Gen3 Data Commons** (e.g., CALYPR) +Example: -## Supported Environments - -- **Local Development** environments -- **HPC Systems** (e.g., ARC) +```bash +git drs remote add gen3 production HTAN_INT/BForePC --cred /path/to/credentials.json +``` -## Commands Overview +Current command split: + +- `git drs push` is the managed data push path +- plain `git push` is plain Git only +- `git drs pull` hydrates tracked pointer files already present in the checkout +- `git drs ls-files` is the local tracked-file inventory command +- `git drs add-url` prepares pointer plus local metadata for existing provider objects + +## Bucket Mapping Model + +End users should not need to know the bucket name. + +Push and pull depend on server-side bucket mapping for the requested scope. That mapping is normally provisioned once by a steward/admin using the bucket commands. + +## Common Commands + +| Command | Description | +| --- | --- | +| `git drs install` | Install global `git-drs` filter config | +| `git drs init` | Explicitly initialize or repair repository-local `git-drs` state | +| `git drs remote add gen3 [remote] ` | Add or refresh a Gen3/Syfon remote | +| `git drs remote list` | List configured remotes | +| `git drs remote remove ` | Remove a configured DRS remote | +| `git drs remote set ` | Set the default remote | +| `git drs track ` | Track files or globs | +| `git drs untrack ` | Stop tracking files or globs | +| `git drs rm ...` | Remove tracked DRS/LFS files from Git | +| `git drs ls-files` | List tracked files and localization state | +| `git drs pull` | Hydrate pointer files in the current checkout | +| `git drs push` | Register/upload objects, reconcile committed deletes, and push refs | +| `git drs add-url` | Add an existing provider object by URL or scoped key | +| `git drs add-ref` | Add a local reference to an existing DRS object | +| `git drs query` | Query a DRS object by ID | +| `git drs copy-records` | Copy Syfon records between remotes for one scope | -| Command | Description | -| ---------------------- | ------------------------------------- | -| `git drs install` | Install global git-drs filter config | -| `git drs init` | Initialize repository | -| `git drs remote add` | Add a DRS remote server | -| `git drs remote list` | List configured remotes | -| `git drs remote set` | Set default remote | -| `git drs add-url` | Add files via provider URLs or configured bucket object keys | -| `git lfs track` | Track file patterns with LFS | -| `git lfs ls-files` | List tracked files | -| `git lfs pull` | Download tracked files | -| `git drs fetch` | Fetch metadata from DRS server | -| `git drs push` | Push objects to DRS server | +## Documentation -Use `--help` with any command for details. See [Commands Reference](docs/commands.md) for complete documentation. +- [Getting Started](docs/getting-started.md) +- [Commands Reference](docs/commands.md) +- [Troubleshooting](docs/troubleshooting.md) +- [Developer Guide](docs/developer-guide.md) +- [GA4GH DRS Scalability Gaps](docs/ga4gh-drs-scalability-gaps.md) ## Requirements -- Git LFS installed and configured -- Access credentials for your DRS server -- Go 1.24+ (for building from source) +- Git +- access credentials for the target Gen3/Syfon deployment +- Go 1.26.2+ for local builds ## Support -- **Issues**: [GitHub Issues](https://github.com/calypr/git-drs/issues) -- **Releases**: [GitHub Releases](https://github.com/calypr/git-drs/releases) -- **Documentation**: See `docs/` folder for detailed guides - -## License - -This project is part of the CALYPR data commons ecosystem. +- [GitHub Issues](https://github.com/calypr/git-drs/issues) +- [GitHub Releases](https://github.com/calypr/git-drs/releases) diff --git a/attic/issue-add-include-pattern-to-git-drs-pull.md b/attic/issue-add-include-pattern-to-git-drs-pull.md new file mode 100644 index 00000000..4217ab3b --- /dev/null +++ b/attic/issue-add-include-pattern-to-git-drs-pull.md @@ -0,0 +1,51 @@ +# Add `-I "pattern"` include filter support to `git drs pull` + +## Summary +Add include-pattern filtering to `git drs pull`, similar to legacy `git lfs pull -I "pattern"` workflows. + +## Motivation +Current `git drs pull` behavior pulls based on repository resolution without a user-facing path pattern filter. Users migrating from `git lfs pull -I` expect selective hydration of files by glob/path. + +## Proposed UX +Support: + +```bash +git drs pull -I "results/*.txt" +git drs pull -I "*.bam" -I "data/**" +git drs pull --include "path/to/file" +``` + +Optional: +- `--exclude` parity (if desired in same change or follow-up) + +## Proposed behavior +1. Parse one or more include patterns (`-I`, `--include`). +2. Resolve candidate pointers as usual. +3. Filter by repo-relative path match before download. +4. Download only matched objects; skip others with clear logging. +5. If no pattern supplied, preserve current default behavior. + +## Scope +- `cmd/pull/main.go` CLI flags and pull selection pipeline +- pointer/path inventory layer (where path<->OID candidates are produced) +- docs: `docs/commands.md`, `docs/getting-started.md`, `docs/troubleshooting.md` +- tests for include filtering semantics + +## Acceptance criteria +- [ ] `git drs pull -I ""` works for a single pattern. +- [ ] Repeated `-I` flags are supported. +- [ ] Include matching is against repo-relative paths. +- [ ] Default `git drs pull` behavior unchanged when no `-I` is passed. +- [ ] Help text documents pattern syntax and examples. +- [ ] Unit/integration tests cover positive and negative matches. + +## Testing matrix +- Single file exact path include. +- Wildcard include (`*.bam`, `data/**`). +- Multiple `-I` values. +- No matches (should no-op cleanly and return success unless policy says otherwise). +- Mixed matched/unmatched objects in same pull run. + +## Notes +This closes a usability gap for users transitioning from `git lfs` CLI habits to `git drs` commands while keeping pull behavior explicit and predictable. + diff --git a/cmd/addref/add-ref.go b/cmd/addref/add-ref.go index 860d675f..84f3b7f0 100644 --- a/cmd/addref/add-ref.go +++ b/cmd/addref/add-ref.go @@ -9,6 +9,7 @@ import ( "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/remoteruntime" "github.com/spf13/cobra" ) @@ -37,7 +38,7 @@ var Cmd = &cobra.Command{ return err } - client, err := cfg.GetRemoteClient(remoteName, logger) + client, err := remoteruntime.New(cfg, remoteName, logger) if err != nil { return err } diff --git a/cmd/addurl/cache.go b/cmd/addurl/cache.go deleted file mode 100644 index 54e64b3f..00000000 --- a/cmd/addurl/cache.go +++ /dev/null @@ -1,264 +0,0 @@ -package addurl - -import ( - "context" - "crypto/sha256" - "encoding/json" - "errors" - "fmt" - "log/slog" - "maps" - "os" - "path/filepath" - "slices" - "strings" - "time" - - "github.com/calypr/git-drs/internal/gitrepo" - "github.com/calypr/git-drs/internal/precommit_cache" -) - -// updatePrecommitCache updates the project's pre-commit cache with a mapping -// from a repository-relative `pathArg` to the given LFS `oid` and records the -// external source URL. It will: -// - require a non-nil `logger` -// - open the pre-commit cache (`precommit_cache.Open`) -// - ensure cache directories exist -// - convert the supplied worktree path to a repository-relative path -// - create or update the per-path JSON entry with the current OID and timestamp -// - create or update the per-OID JSON entry listing paths that reference it, -// the external URL, and a content-change flag when the path's OID changed -// - remove the path from the previous OID entry when the content changed -// -// Parameters: -// - ctx: context for operations that may be cancellable -// - logger: a non-nil `*slog.Logger` used for warnings; if nil the function -// returns an error -// - pathArg: the worktree path to record (absolute or relative); must not be empty -// - oid: the LFS object id (string) to associate with the path -// - externalURL: optional external source URL for the object; empty string is allowed -// -// Returns an error if any cache operation, path resolution, or I/O fails. -func updatePrecommitCache(ctx context.Context, logger *slog.Logger, pathArg, oid, externalURL string) error { - if logger == nil { - return errors.New("logger is required") - } - // Open pre-commit cache. Returns a configured Cache or error. - cache, err := precommit_cache.Open(ctx) - if err != nil { - return err - } - - // Ensure cache directories exist. - if err := ensureCacheDirs(cache, logger); err != nil { - return err - } - - // Convert worktree path to repository-relative path. - relPath, err := repoRelativePath(pathArg) - if err != nil { - return err - } - - // Current timestamp in RFC3339 format (UTC). - now := time.Now().UTC().Format(time.RFC3339) - - // Read previous path entry, if any. - pathFile := cachePathEntryFile(cache, relPath) - prevEntry, prevExists, err := readPathEntry(pathFile) - if err != nil { - return err - } - // track whether content changed for this path - contentChanged := prevExists && prevEntry.LFSOID != "" && prevEntry.LFSOID != oid - - if err := writeJSONAtomic(pathFile, precommit_cache.PathEntry{ - Path: relPath, - LFSOID: oid, - UpdatedAt: now, - }); err != nil { - return err - } - - if err := upsertOIDEntry(cache, oid, relPath, externalURL, now, contentChanged); err != nil { - return err - } - - if contentChanged { - _ = removePathFromOID(cache, prevEntry.LFSOID, relPath, now) - } - - return nil -} - -// ensureCacheDirs verifies and creates the pre-commit cache directory layout -// (paths and oids directories). It logs a warning when creating a missing -// cache root. -func ensureCacheDirs(cache *precommit_cache.Cache, logger *slog.Logger) error { - if cache == nil { - return errors.New("cache is nil") - } - if _, err := os.Stat(cache.Root); err != nil { - if os.IsNotExist(err) { - logger.Warn("pre-commit cache directory missing; creating", "path", cache.Root) - } else { - return err - } - } - if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil { - return fmt.Errorf("create cache paths dir: %w", err) - } - if err := os.MkdirAll(cache.OIDsDir, 0o755); err != nil { - return fmt.Errorf("create cache oids dir: %w", err) - } - return nil -} - -// repoRelativePath converts a worktree path (absolute or relative) to a -// repository-relative path. It resolves symlinks and ensures the path is -// contained within the repository root. -func repoRelativePath(pathArg string) (string, error) { - if pathArg == "" { - return "", errors.New("empty worktree path") - } - root, err := gitrepo.GitTopLevel() - if err != nil { - return "", err - } - root, err = filepath.EvalSymlinks(root) - if err != nil { - return "", err - } - clean := filepath.Clean(pathArg) - if filepath.IsAbs(clean) { - clean, err = filepath.EvalSymlinks(clean) - if err != nil { - return "", err - } - rel, err := filepath.Rel(root, clean) - if err != nil { - return "", err - } - if strings.HasPrefix(rel, "..") { - return "", fmt.Errorf("path %s is outside repo root %s", clean, root) - } - return filepath.ToSlash(rel), nil - } - return filepath.ToSlash(clean), nil -} - -// cachePathEntryFile returns the filesystem path to the JSON path-entry file -// for the given repository-relative path within the provided Cache. -func cachePathEntryFile(cache *precommit_cache.Cache, path string) string { - return filepath.Join(cache.PathsDir, precommit_cache.EncodePath(path)+".json") -} - -// cacheOIDEntryFile returns the filesystem path to the JSON OID-entry file -// for the given LFS OID. The file is named by sha256(oid) to avoid filesystem -// restrictions and collisions. -func cacheOIDEntryFile(cache *precommit_cache.Cache, oid string) string { - sum := sha256.Sum256([]byte(oid)) - return filepath.Join(cache.OIDsDir, fmt.Sprintf("%x.json", sum[:])) -} - -// readPathEntry reads and parses a JSON PathEntry from disk. It returns the -// parsed entry, a boolean indicating existence, or an error on I/O/parse -// failure. -func readPathEntry(path string) (*precommit_cache.PathEntry, bool, error) { - data, err := os.ReadFile(path) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil, false, nil - } - return nil, false, err - } - var entry precommit_cache.PathEntry - if err := json.Unmarshal(data, &entry); err != nil { - return nil, false, err - } - return &entry, true, nil -} - -// readOIDEntry reads and parses a JSON OIDEntry from disk. If the file is -// missing it returns a freshly initialized entry (with LFSOID set to the -// supplied oid and UpdatedAt set to now). -func readOIDEntry(path string, oid string, now string) (*precommit_cache.OIDEntry, error) { - entry := &precommit_cache.OIDEntry{ - LFSOID: oid, - Paths: []string{}, - UpdatedAt: now, - } - data, err := os.ReadFile(path) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return entry, nil - } - return nil, err - } - if err := json.Unmarshal(data, entry); err != nil { - return nil, err - } - entry.LFSOID = oid - return entry, nil -} - -// upsertOIDEntry creates or updates the OID entry for `oid`, ensuring `path` -// is listed among its Paths, updating ExternalURL when provided, and setting -// content-change/state fields. The updated entry is written atomically. -func upsertOIDEntry(cache *precommit_cache.Cache, oid, path, externalURL, now string, contentChanged bool) error { - if cache == nil { - return errors.New("cache is nil") - } - oidFile := cacheOIDEntryFile(cache, oid) - entry, err := readOIDEntry(oidFile, oid, now) - if err != nil { - return err - } - - pathSet := make(map[string]struct{}, len(entry.Paths)+1) - for _, p := range entry.Paths { - pathSet[p] = struct{}{} - } - if path != "" { - pathSet[path] = struct{}{} - } - entry.Paths = sortedKeys(pathSet) - entry.UpdatedAt = now - entry.ContentChange = entry.ContentChange || contentChanged - if strings.TrimSpace(externalURL) != "" { - entry.ExternalURL = externalURL - } - - return writeJSONAtomic(oidFile, entry) -} - -// removePathFromOID removes `path` from the OID entry for `oid` and writes -// the updated entry atomically. Missing entries are treated as empty. -func removePathFromOID(cache *precommit_cache.Cache, oid, path, now string) error { - if cache == nil { - return errors.New("cache is nil") - } - oidFile := cacheOIDEntryFile(cache, oid) - entry, err := readOIDEntry(oidFile, oid, now) - if err != nil { - return err - } - pathSet := make(map[string]struct{}, len(entry.Paths)) - for _, p := range entry.Paths { - if p == path { - continue - } - pathSet[p] = struct{}{} - } - entry.Paths = sortedKeys(pathSet) - entry.UpdatedAt = now - - return writeJSONAtomic(oidFile, entry) -} - -// sortedKeys returns a sorted slice of keys from the provided string-set map. -func sortedKeys(set map[string]struct{}) []string { - keys := slices.Collect(maps.Keys(set)) - slices.Sort(keys) - return keys -} diff --git a/cmd/addurl/params.go b/cmd/addurl/input.go similarity index 50% rename from cmd/addurl/params.go rename to cmd/addurl/input.go index 778485ca..bfcb0b79 100644 --- a/cmd/addurl/params.go +++ b/cmd/addurl/input.go @@ -3,12 +3,8 @@ package addurl import ( "fmt" "net/url" - "os" - "path" "strings" - "github.com/calypr/git-drs/internal/gitrepo" - sycloud "github.com/calypr/syfon/client/cloud" "github.com/spf13/cobra" ) @@ -63,18 +59,6 @@ func resolvePathArg(sourceArg string, args []string) (string, error) { return strings.Trim(strings.TrimSpace(sourceArg), "/"), nil } -func buildObjectParameters(objectURL, pathArg, sha256 string) sycloud.ObjectParameters { - return sycloud.ObjectParameters{ - ObjectURL: objectURL, - S3Region: firstNonEmpty(os.Getenv("AWS_REGION"), os.Getenv("AWS_DEFAULT_REGION"), os.Getenv("TEST_BUCKET_REGION")), - S3Endpoint: firstNonEmpty(os.Getenv("AWS_ENDPOINT_URL_S3"), os.Getenv("AWS_ENDPOINT_URL"), os.Getenv("TEST_BUCKET_ENDPOINT")), - S3AccessKey: firstNonEmpty(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("TEST_BUCKET_ACCESS_KEY")), - S3SecretKey: firstNonEmpty(os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("TEST_BUCKET_SECRET_KEY")), - SHA256: sha256, - DestinationPath: pathArg, - } -} - func looksLikeCloudURL(raw string) bool { u, err := url.Parse(strings.TrimSpace(raw)) if err != nil { @@ -91,37 +75,6 @@ func looksLikeCloudURL(raw string) bool { } } -func resolveObjectURL(input addURLInput, scope gitrepo.ResolvedBucketScope) (string, error) { - if looksLikeCloudURL(input.sourceArg) { - return input.sourceArg, nil - } - if input.scheme == "" { - return "", fmt.Errorf("object key mode requires --scheme because local bucket mappings store bucket/prefix but not provider scheme") - } - key := joinObjectKey(scope.Prefix, input.sourceArg) - switch input.scheme { - case "s3": - return fmt.Sprintf("s3://%s/%s", scope.Bucket, key), nil - case "gs", "gcs": - return fmt.Sprintf("gs://%s/%s", scope.Bucket, key), nil - case "azblob", "az": - return "", fmt.Errorf("object key mode for Azure requires a full azblob:// URL because the local mapping does not store account_name") - default: - return "", fmt.Errorf("unsupported --scheme %q (expected s3 or gs, or pass a full object URL)", input.scheme) - } -} - -func joinObjectKey(prefix, key string) string { - parts := make([]string, 0, 2) - if p := strings.Trim(strings.TrimSpace(prefix), "/"); p != "" { - parts = append(parts, p) - } - if k := strings.Trim(strings.TrimSpace(key), "/"); k != "" { - parts = append(parts, k) - } - return path.Join(parts...) -} - func firstNonEmpty(values ...string) string { for _, v := range values { v = strings.TrimSpace(v) diff --git a/cmd/addurl/inspect.go b/cmd/addurl/inspect.go new file mode 100644 index 00000000..81a95824 --- /dev/null +++ b/cmd/addurl/inspect.go @@ -0,0 +1,141 @@ +package addurl + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/calypr/git-drs/internal/remoteruntime" + sycloud "github.com/calypr/syfon/client/cloud" + syrequest "github.com/calypr/syfon/client/request" +) + +const internalInspectObjectPath = "/data/inspect" + +type inspectedObject struct { + objectURL string + info *sycloud.ObjectInfo +} + +type internalInspectObjectRequest struct { + Organization string `json:"organization,omitempty"` + Project string `json:"project,omitempty"` + Key string `json:"key,omitempty"` + Scheme string `json:"scheme,omitempty"` + ObjectURL string `json:"object_url,omitempty"` +} + +type internalInspectObjectResponse struct { + ObjectURL string `json:"object_url"` + Provider string `json:"provider"` + Bucket string `json:"bucket"` + Key string `json:"key"` + Path string `json:"path"` + SizeBytes int64 `json:"size_bytes"` + MetaSHA256 string `json:"meta_sha256,omitempty"` + ETag string `json:"etag,omitempty"` + LastModTime string `json:"last_modified,omitempty"` +} + +func inspectRemoteObjectViaServer(ctx context.Context, drsCtx *remoteruntime.GitContext, input addURLInput) (*inspectedObject, error) { + if drsCtx == nil || drsCtx.Client == nil || drsCtx.Client.Requestor() == nil { + return nil, fmt.Errorf("remote-backed add-url inspection requires a configured syfon client") + } + + req := internalInspectObjectRequest{} + target := strings.TrimSpace(input.sourceArg) + if looksLikeCloudURL(target) { + if !strings.HasPrefix(strings.ToLower(target), "s3://") { + return nil, fmt.Errorf("remote-backed add-url inspection currently supports only s3:// URLs") + } + req.ObjectURL = target + } else { + if strings.TrimSpace(input.scheme) == "" { + return nil, fmt.Errorf("object key mode requires --scheme because the remote must know which provider to inspect") + } + if strings.ToLower(strings.TrimSpace(input.scheme)) != "s3" { + return nil, fmt.Errorf("remote-backed add-url inspection currently supports only --scheme s3") + } + req.Organization = strings.TrimSpace(drsCtx.Organization) + req.Project = strings.TrimSpace(drsCtx.ProjectId) + req.Key = strings.Trim(target, "/") + req.Scheme = "s3" + } + + var resp internalInspectObjectResponse + if err := drsCtx.Client.Requestor().Do(ctx, http.MethodPost, internalInspectObjectPath, req, &resp); err != nil { + return nil, mapInspectError(target, err) + } + return &inspectedObject{ + objectURL: resp.ObjectURL, + info: &sycloud.ObjectInfo{ + Bucket: resp.Bucket, + Key: resp.Key, + Path: resp.Path, + SizeBytes: resp.SizeBytes, + MetaSHA256: resp.MetaSHA256, + ETag: resp.ETag, + LastModTime: parseInspectLastModified(resp.LastModTime), + }, + }, nil +} + +func parseInspectLastModified(raw string) time.Time { + raw = strings.TrimSpace(raw) + if raw == "" { + return time.Time{} + } + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return time.Time{} + } + return t +} + +func mapInspectError(target string, err error) error { + var respErr *syrequest.ResponseError + if !errorAsResponse(err, &respErr) { + return fmt.Errorf("remote-backed add-url inspection failed for %q: %w", target, err) + } + + body := strings.TrimSpace(respErr.Body) + switch respErr.Status { + case http.StatusBadRequest: + return fmt.Errorf("remote-backed add-url inspection rejected %q: %s", target, fallbackInspectMessage(body, "invalid request")) + case http.StatusForbidden: + return fmt.Errorf("remote-backed add-url inspection was denied for %q: %s", target, fallbackInspectMessage(body, "permission denied")) + case http.StatusNotFound: + if looksLikeInspectRouteMissing(body) { + return fmt.Errorf("remote-backed add-url inspection is unavailable on this Syfon remote; upgrade Syfon to a version that implements %s", internalInspectObjectPath) + } + return fmt.Errorf("remote-backed add-url inspection could not find %q: %s", target, fallbackInspectMessage(body, "not found")) + default: + return fmt.Errorf("remote-backed add-url inspection failed for %q: %s", target, fallbackInspectMessage(body, err.Error())) + } +} + +func looksLikeInspectRouteMissing(body string) bool { + body = strings.ToLower(strings.TrimSpace(body)) + return body == "" || + strings.Contains(body, "cannot post /data/inspect") || + strings.Contains(body, "cannot post /data/inspect/") || + strings.Contains(body, "not found") +} + +func fallbackInspectMessage(body string, fallback string) string { + if strings.TrimSpace(body) != "" { + return strings.TrimSpace(body) + } + return fallback +} + +func errorAsResponse(err error, target **syrequest.ResponseError) bool { + respErr, ok := err.(*syrequest.ResponseError) + if ok { + *target = respErr + return true + } + return false +} diff --git a/cmd/addurl/io.go b/cmd/addurl/io.go deleted file mode 100644 index 22876e1b..00000000 --- a/cmd/addurl/io.go +++ /dev/null @@ -1,120 +0,0 @@ -package addurl - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - - sycloud "github.com/calypr/syfon/client/cloud" - "github.com/spf13/cobra" -) - -// writePointerFile writes a Git LFS pointer file at the given worktree path -// referencing the supplied oid and recording sizeBytes. It creates parent -// directories as needed and validates the path is non-empty. -func writePointerFile(pathArg, oid string, sizeBytes int64) error { - pointer := fmt.Sprintf( - "version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", - oid, sizeBytes, - ) - if pathArg == "" { - return fmt.Errorf("empty worktree path") - } - safePath := filepath.Clean(pathArg) - dir := filepath.Dir(safePath) - if dir != "." && dir != "/" { - if err := os.MkdirAll(dir, 0755); err != nil { - return fmt.Errorf("mkdir %s: %w", dir, err) - } - } - if err := os.WriteFile(safePath, []byte(pointer), 0644); err != nil { - return fmt.Errorf("write %s: %w", safePath, err) - } - - if _, err := fmt.Fprintf(os.Stderr, "Added Git LFS pointer file at %s\n", safePath); err != nil { - return fmt.Errorf("stderr write: %w", err) - } - return nil -} - -// maybeTrackLFS ensures the supplied path is tracked by Git LFS by invoking -// the provided gitLFSTrack callback when the path is not already tracked. -// It reports the addition to stderr for user guidance. -func maybeTrackLFS(ctx context.Context, gitLFSTrack func(context.Context, string) (bool, error), pathArg string, isTracked bool) error { - if isTracked { - return nil - } - if _, err := gitLFSTrack(ctx, pathArg); err != nil { - return fmt.Errorf("git lfs track %s: %w", pathArg, err) - } - - if _, err := fmt.Fprintf(os.Stderr, "Info: Added to Git LFS. Remember to `git add %s` and `git commit ...`", pathArg); err != nil { - return fmt.Errorf("stderr write: %w", err) - } - return nil -} - -// printResolvedInfo writes a human-readable summary of resolved Git/LFS and -// cloud object information to the command's stdout for user confirmation. -func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, objectInfo *sycloud.ObjectInfo, pathArg string, isTracked bool, sha256 string) error { - if _, err := fmt.Fprintf(cmd.OutOrStdout(), ` -Resolved Git LFS Object Info ----------------------------- -Git common dir : %s -LFS storage : %s - -Cloud object ------------- -Bucket : %s -Key : %s -Worktree name : %s -Size (bytes) : %d -SHA256 (meta) : %s -ETag : %s -Last modified : %s - -Worktree -------------- -path : %s -tracked by LFS : %v -sha256 param : %s - -`, - gitCommonDir, - lfsRoot, - objectInfo.Bucket, - objectInfo.Key, - objectInfo.Path, - objectInfo.SizeBytes, - objectInfo.MetaSHA256, - objectInfo.ETag, - objectInfo.LastModTime.Format("2006-01-02T15:04:05Z07:00"), - pathArg, - isTracked, - sha256, - ); err != nil { - return fmt.Errorf("print resolved object info: %w", err) - } - return nil -} - -// writeJSONAtomic marshals `value` to JSON and writes it to `path` atomically -// by writing to a temporary file in the same directory and renaming it. It -// ensures parent directories exist. -func writeJSONAtomic(path string, value any) error { - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err - } - tmp := path + ".tmp" - data, err := json.Marshal(value) - if err != nil { - return err - } - if err := os.WriteFile(tmp, data, 0o644); err != nil { - return err - } - return os.Rename(tmp, path) -} diff --git a/cmd/addurl/local_state.go b/cmd/addurl/local_state.go new file mode 100644 index 00000000..d76d3a56 --- /dev/null +++ b/cmd/addurl/local_state.go @@ -0,0 +1,260 @@ +package addurl + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "time" + + "github.com/calypr/git-drs/internal/drsobject" + "github.com/calypr/git-drs/internal/gitrepo" + "github.com/calypr/git-drs/internal/precommit_cache" + drsapi "github.com/calypr/syfon/apigen/client/drs" + sycloud "github.com/calypr/syfon/client/cloud" + "github.com/google/uuid" + "github.com/spf13/cobra" +) + +type addURLDrsFile struct { + Name string + Size int64 + Oid string +} + +func drsobjectBuilder(bucket, organization, project, storagePrefix string) drsobject.Builder { + builder := drsobject.NewBuilder(bucket, project) + builder.Organization = organization + builder.StoragePrefix = storagePrefix + return builder +} + +func writeAddURLDrsObject(builder drsobject.Builder, file addURLDrsFile, objectPath string) (*drsapi.DrsObject, error) { + existing, err := drsobject.ReadObject(gitrepo.DRSObjectsPath, file.Oid) + var drsObj *drsapi.DrsObject + if err == nil && existing != nil { + drsObj = existing + name := file.Name + drsObj.Name = &name + drsObj.Size = file.Size + } else { + drsID := uuid.NewSHA1(drsobject.UUIDNamespace, []byte(fmt.Sprintf("%s:%s", builder.Project, drsobject.NormalizeOid(file.Oid)))).String() + drsObj, err = builder.Build(file.Name, file.Oid, file.Size, drsID) + if err != nil { + return nil, fmt.Errorf("error building DRS object for oid %s: %w", file.Oid, err) + } + } + + if objectPath != "" { + if drsObj.AccessMethods != nil && len(*drsObj.AccessMethods) > 0 { + am := &(*drsObj.AccessMethods)[0] + am.AccessUrl = &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: objectPath} + } else { + drsObj.AccessMethods = &[]drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: objectPath}, + }} + } + } + + if err := drsobject.WriteObject(gitrepo.DRSObjectsPath, drsObj, file.Oid); err != nil { + return nil, fmt.Errorf("error writing DRS object for oid %s: %w", file.Oid, err) + } + return drsObj, nil +} + +func placeholderOIDForUnknownSHA(etag string, sourceURL string) (string, error) { + e := strings.TrimSpace(strings.Trim(etag, `"`)) + src := strings.TrimSpace(sourceURL) + if e == "" { + return "", fmt.Errorf("etag is required for placeholder oid") + } + if src == "" { + return "", fmt.Errorf("source URL is required for placeholder oid") + } + sum := sha256.Sum256([]byte("git-drs-add-url-placeholder:v2\netag=" + e + "\nsource=" + src + "\n")) + return fmt.Sprintf("%x", sum[:]), nil +} + +// updatePrecommitCache updates the project's pre-commit cache with a mapping +// from a repository-relative `pathArg` to the given LFS `oid` and records the +// external source URL. +func updatePrecommitCache(ctx context.Context, logger *slog.Logger, pathArg, oid, externalURL string) error { + if logger == nil { + return errors.New("logger is required") + } + cache, err := precommit_cache.Open(ctx) + if err != nil { + return err + } + if err := ensureCacheDirs(cache, logger); err != nil { + return err + } + + relPath, err := repoRelativePath(pathArg) + if err != nil { + return err + } + + now := time.Now().UTC().Format(time.RFC3339) + prevEntry, prevExists, err := precommit_cache.ReadPathEntry(cache, relPath) + if err != nil { + return err + } + contentChanged := prevExists && prevEntry.LFSOID != "" && prevEntry.LFSOID != oid + + if err := precommit_cache.WritePathEntry(cache, precommit_cache.PathEntry{ + Path: relPath, + LFSOID: oid, + UpdatedAt: now, + }); err != nil { + return err + } + if err := precommit_cache.UpsertOIDPath(cache, oid, "", relPath, externalURL, now, contentChanged); err != nil { + return err + } + if contentChanged { + if err := precommit_cache.RemoveOIDPath(cache, prevEntry.LFSOID, relPath, now); err != nil { + return fmt.Errorf("remove stale OID path mapping for %s: %w", relPath, err) + } + } + return nil +} + +func ensureCacheDirs(cache *precommit_cache.Cache, logger *slog.Logger) error { + if cache == nil { + return errors.New("cache is nil") + } + if _, err := os.Stat(cache.Root); err != nil { + if os.IsNotExist(err) { + logger.Warn("pre-commit cache directory missing; creating", "path", cache.Root) + } else { + return err + } + } + if err := precommit_cache.EnsureLayout(cache); err != nil { + return fmt.Errorf("create cache layout: %w", err) + } + return nil +} + +func repoRelativePath(pathArg string) (string, error) { + if pathArg == "" { + return "", errors.New("empty worktree path") + } + root, err := gitrepo.GitTopLevel() + if err != nil { + return "", err + } + root, err = filepath.EvalSymlinks(root) + if err != nil { + return "", err + } + clean := filepath.Clean(pathArg) + if filepath.IsAbs(clean) { + clean, err = filepath.EvalSymlinks(clean) + if err != nil { + return "", err + } + rel, err := filepath.Rel(root, clean) + if err != nil { + return "", err + } + if strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("path %s is outside repo root %s", clean, root) + } + return filepath.ToSlash(rel), nil + } + return filepath.ToSlash(clean), nil +} + +func writePointerFile(pathArg, oid string, sizeBytes int64) error { + pointer := fmt.Sprintf( + "version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", + oid, sizeBytes, + ) + if pathArg == "" { + return fmt.Errorf("empty worktree path") + } + safePath := filepath.Clean(pathArg) + dir := filepath.Dir(safePath) + if dir != "." && dir != "/" { + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err := os.WriteFile(safePath, []byte(pointer), 0o644); err != nil { + return fmt.Errorf("write %s: %w", safePath, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Added Git LFS pointer file at %s\n", safePath); err != nil { + return fmt.Errorf("stderr write: %w", err) + } + return nil +} + +func maybeTrackLFS(ctx context.Context, gitLFSTrack func(context.Context, string) (bool, error), pathArg string, isTracked bool) error { + if isTracked { + return nil + } + if _, err := gitLFSTrack(ctx, pathArg); err != nil { + return fmt.Errorf("git lfs track %s: %w", pathArg, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Info: Added to Git LFS. Remember to `git add %s` and `git commit ...`", pathArg); err != nil { + return fmt.Errorf("stderr write: %w", err) + } + return nil +} + +func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, objectInfo *sycloud.ObjectInfo, pathArg string, isTracked bool, sha256 string) error { + if _, err := fmt.Fprintf(cmd.OutOrStdout(), ` +Resolved Git LFS Object Info +---------------------------- +Git common dir : %s +LFS storage : %s + +Cloud object +------------ +Bucket : %s +Key : %s +Worktree name : %s +Size (bytes) : %d +SHA256 (meta) : %s +ETag : %s +Last modified : %s + +Worktree +------------- +path : %s +tracked by LFS : %v +sha256 param : %s + +`, + gitCommonDir, + lfsRoot, + objectInfo.Bucket, + objectInfo.Key, + objectInfo.Path, + objectInfo.SizeBytes, + objectInfo.MetaSHA256, + objectInfo.ETag, + objectInfo.LastModTime.Format("2006-01-02T15:04:05Z07:00"), + pathArg, + isTracked, + sha256, + ); err != nil { + return fmt.Errorf("print resolved object info: %w", err) + } + return nil +} diff --git a/cmd/addurl/main_test.go b/cmd/addurl/main_test.go index 26a6456a..175b85d8 100644 --- a/cmd/addurl/main_test.go +++ b/cmd/addurl/main_test.go @@ -3,7 +3,6 @@ package addurl import ( "bytes" "context" - "crypto/sha256" "encoding/json" "fmt" "io" @@ -15,12 +14,11 @@ import ( "testing" "time" - "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drsobject" "github.com/calypr/git-drs/internal/gitrepo" - "github.com/calypr/git-drs/internal/lfs" "github.com/calypr/git-drs/internal/precommit_cache" + "github.com/calypr/git-drs/internal/remoteruntime" sycloud "github.com/calypr/syfon/client/cloud" ) @@ -75,15 +73,26 @@ func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { service := NewAddURLService() resetStubs := stubAddURLDeps(t, service, - func(ctx context.Context, in sycloud.ObjectParameters) (*sycloud.ObjectInfo, error) { - return &sycloud.ObjectInfo{ - Bucket: "bucket", - Key: "path/to/file.bin", - Path: "file.bin", - SizeBytes: int64(11), - MetaSHA256: "", - ETag: "abcd1234", - LastModTime: time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), + func(ctx context.Context, drsCtx *remoteruntime.GitContext, in addURLInput) (*inspectedObject, error) { + return &inspectedObject{ + objectURL: "s3://bucket/path/to/file.bin", + info: &sycloud.ObjectInfo{ + Bucket: "bucket", + Key: "path/to/file.bin", + Path: "file.bin", + SizeBytes: int64(11), + MetaSHA256: "", + ETag: "abcd1234", + LastModTime: time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), + }, + }, nil + }, + func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*remoteruntime.GitContext, error) { + return &remoteruntime.GitContext{ + Organization: "calypr", + ProjectId: "calypr-dev", + BucketName: "cbds", + StoragePrefix: "", }, nil }, func(path string) (bool, error) { @@ -100,9 +109,9 @@ func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { t.Fatalf("service.Run error: %v", err) } - oid, err := lfs.SyntheticOIDFromETag("abcd1234") + oid, err := placeholderOIDForUnknownSHA("abcd1234", "s3://bucket/path/to/file.bin") if err != nil { - t.Fatalf("SyntheticOIDFromETag: %v", err) + t.Fatalf("placeholderOIDForUnknownSHA: %v", err) } pointerPath := filepath.Join(tempDir, "path/to/file.bin") @@ -120,18 +129,11 @@ func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { } lfsObject := filepath.Join(lfsRoot, "objects", oid[0:2], oid[2:4], oid) - if _, err := os.Stat(lfsObject); err != nil { - t.Fatalf("expected LFS object at %s: %v", lfsObject, err) - } - sentinel, err := os.ReadFile(lfsObject) - if err != nil { - t.Fatalf("read sentinel: %v", err) - } - if !lfs.IsAddURLSentinelBytes(sentinel) { - t.Fatalf("expected add-url sentinel payload, got: %q", string(sentinel)) + if _, err := os.Stat(lfsObject); !os.IsNotExist(err) { + t.Fatalf("expected no local LFS object payload at %s, got err=%v", lfsObject, err) } - drsObject, err := drsobject.ReadObject(common.DRS_OBJS_PATH, oid) + drsObject, err := drsobject.ReadObject(gitrepo.DRSObjectsPath, oid) if err != nil { t.Fatalf("read drs object: %v", err) } @@ -143,43 +145,37 @@ func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { } } -func TestParseAddURLInput_DoesNotRequireAWSFlags(t *testing.T) { - cmd := NewCommand() - in, err := parseAddURLInput(cmd, []string{"gs://bucket/path/to/file.bin"}) +func TestPlaceholderOIDForUnknownSHA(t *testing.T) { + oid1, err := placeholderOIDForUnknownSHA("etag-abc", "s3://bucket/key") if err != nil { - t.Fatalf("parseAddURLInput error: %v", err) + t.Fatalf("placeholderOIDForUnknownSHA: %v", err) } - if in.sourceArg != "gs://bucket/path/to/file.bin" { - t.Fatalf("unexpected source url: %s", in.sourceArg) + oid2, err := placeholderOIDForUnknownSHA(`"etag-abc"`, "s3://bucket/key") + if err != nil { + t.Fatalf("placeholderOIDForUnknownSHA quoted: %v", err) } - if in.path != "path/to/file.bin" { - t.Fatalf("unexpected path: %s", in.path) + if oid1 != oid2 { + t.Fatalf("expected trimmed etag handling to be stable: %s vs %s", oid1, oid2) + } + if len(oid1) != 64 { + t.Fatalf("expected 64-char oid, got %q", oid1) + } + if _, err := placeholderOIDForUnknownSHA("", "s3://bucket/key"); err == nil { + t.Fatal("expected empty etag error") } } -func TestParseAddURLInput_PassesS3EnvHints(t *testing.T) { - t.Setenv("TEST_BUCKET_REGION", "us-east-1") - t.Setenv("TEST_BUCKET_ENDPOINT", "https://aced-storage.ohsu.edu") - t.Setenv("TEST_BUCKET_ACCESS_KEY", "cbds-user") - t.Setenv("TEST_BUCKET_SECRET_KEY", "cbds-secret") - +func TestParseAddURLInput_DoesNotRequireAWSFlags(t *testing.T) { cmd := NewCommand() - in, err := parseAddURLInput(cmd, []string{"s3://cbds/path/to/file.bin"}) + in, err := parseAddURLInput(cmd, []string{"gs://bucket/path/to/file.bin"}) if err != nil { t.Fatalf("parseAddURLInput error: %v", err) } - params := buildObjectParameters("s3://cbds/path/to/file.bin", in.path, in.sha256) - if params.S3Region != "us-east-1" { - t.Fatalf("unexpected S3Region: %s", params.S3Region) - } - if params.S3Endpoint != "https://aced-storage.ohsu.edu" { - t.Fatalf("unexpected S3Endpoint: %s", params.S3Endpoint) - } - if params.S3AccessKey != "cbds-user" { - t.Fatalf("unexpected S3AccessKey: %s", params.S3AccessKey) + if in.sourceArg != "gs://bucket/path/to/file.bin" { + t.Fatalf("unexpected source url: %s", in.sourceArg) } - if params.S3SecretKey != "cbds-secret" { - t.Fatalf("unexpected S3SecretKey: %s", params.S3SecretKey) + if in.path != "path/to/file.bin" { + t.Fatalf("unexpected path: %s", in.path) } } @@ -204,38 +200,6 @@ func TestParseAddURLInput_ObjectKeyModeDefaultsPathToKey(t *testing.T) { } } -func TestResolveObjectURL_UsesConfiguredBucketScopeForObjectKeyMode(t *testing.T) { - input := addURLInput{ - sourceArg: "nested/path/file.bin", - scheme: "s3", - } - scope := gitrepo.ResolvedBucketScope{ - Bucket: "mapped-bucket", - Prefix: "mapped/prefix", - } - - got, err := resolveObjectURL(input, scope) - if err != nil { - t.Fatalf("resolveObjectURL: %v", err) - } - if got != "s3://mapped-bucket/mapped/prefix/nested/path/file.bin" { - t.Fatalf("unexpected object URL: %s", got) - } -} - -func TestResolveObjectURL_RejectsObjectKeyModeWithoutScheme(t *testing.T) { - _, err := resolveObjectURL(addURLInput{sourceArg: "nested/path/file.bin"}, gitrepo.ResolvedBucketScope{ - Bucket: "mapped-bucket", - Prefix: "mapped/prefix", - }) - if err == nil { - t.Fatal("expected error") - } - if !strings.Contains(err.Error(), "requires --scheme") { - t.Fatalf("unexpected error: %v", err) - } -} - func TestUpdatePrecommitCacheWritesEntries(t *testing.T) { repo := setupGitRepo(t) path := filepath.Join(repo, "data", "file.bin") @@ -280,8 +244,7 @@ func TestUpdatePrecommitCacheWritesEntries(t *testing.T) { t.Fatalf("expected updated_at to be set") } - oidSum := sha256.Sum256([]byte(oid)) - oidEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", oidSum[:])) + oidEntryFile := precommit_cache.OIDEntryPath(&precommit_cache.Cache{OIDsDir: oidDir}, oid) oidData, err := os.ReadFile(oidEntryFile) if err != nil { t.Fatalf("read oid entry: %v", err) @@ -331,8 +294,8 @@ func TestUpdatePrecommitCacheContentChanged(t *testing.T) { cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") oidDir := filepath.Join(cacheRoot, "oids") - firstSum := sha256.Sum256([]byte(firstOID)) - firstEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", firstSum[:])) + cache := &precommit_cache.Cache{OIDsDir: oidDir} + firstEntryFile := precommit_cache.OIDEntryPath(cache, firstOID) firstData, err := os.ReadFile(firstEntryFile) if err != nil { t.Fatalf("read first oid entry: %v", err) @@ -345,8 +308,7 @@ func TestUpdatePrecommitCacheContentChanged(t *testing.T) { t.Fatalf("expected old oid entry paths to be empty, got %v", firstEntry.Paths) } - secondSum := sha256.Sum256([]byte(secondOID)) - secondEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", secondSum[:])) + secondEntryFile := precommit_cache.OIDEntryPath(cache, secondOID) secondData, err := os.ReadFile(secondEntryFile) if err != nil { t.Fatalf("read second oid entry: %v", err) @@ -363,26 +325,25 @@ func TestUpdatePrecommitCacheContentChanged(t *testing.T) { } } -// deprecated test case: now that we always "trust" the client-provided SHA256, this case is not applicable -//func TestRunAddURL_SHA256Mismatch(t *testing.T) { -// ... -//} - func stubAddURLDeps( t *testing.T, service *AddURLService, - inspectFn func(context.Context, sycloud.ObjectParameters) (*sycloud.ObjectInfo, error), + inspectFn func(context.Context, *remoteruntime.GitContext, addURLInput) (*inspectedObject, error), + getRemoteClientFn func(*config.Config, config.Remote, *slog.Logger) (*remoteruntime.GitContext, error), isTrackedFn func(string) (bool, error), ) func() { t.Helper() - origInspect := service.inspectObject + origInspect := service.inspectRemoteObject + origGetRemoteClient := service.getRemoteClient origIsTracked := service.isLFSTracked - service.inspectObject = inspectFn + service.inspectRemoteObject = inspectFn + service.getRemoteClient = getRemoteClientFn service.isLFSTracked = isTrackedFn return func() { - service.inspectObject = origInspect + service.inspectRemoteObject = origInspect + service.getRemoteClient = origGetRemoteClient service.isLFSTracked = origIsTracked } } diff --git a/cmd/addurl/remote_inspect_test.go b/cmd/addurl/remote_inspect_test.go new file mode 100644 index 00000000..21d5cf45 --- /dev/null +++ b/cmd/addurl/remote_inspect_test.go @@ -0,0 +1,33 @@ +package addurl + +import ( + "net/http" + "strings" + "testing" + + syrequest "github.com/calypr/syfon/client/request" +) + +func TestMapInspectError_UpgradeMessageForMissingRoute(t *testing.T) { + err := mapInspectError("s3://bucket/key", &syrequest.ResponseError{ + Method: http.MethodPost, + URL: "https://example.test/data/inspect", + Status: http.StatusNotFound, + Body: "Cannot POST /data/inspect", + }) + if err == nil || !strings.Contains(err.Error(), "upgrade Syfon") { + t.Fatalf("expected upgrade message, got %v", err) + } +} + +func TestMapInspectError_ActionableForbidden(t *testing.T) { + err := mapInspectError("s3://bucket/key", &syrequest.ResponseError{ + Method: http.MethodPost, + URL: "https://example.test/data/inspect", + Status: http.StatusForbidden, + Body: "provider denied access to s3://bucket/key", + }) + if err == nil || !strings.Contains(err.Error(), "was denied") { + t.Fatalf("expected denied message, got %v", err) + } +} diff --git a/cmd/addurl/run.go b/cmd/addurl/run.go new file mode 100644 index 00000000..4c440557 --- /dev/null +++ b/cmd/addurl/run.go @@ -0,0 +1,157 @@ +package addurl + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/gitrepo" + "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/remoteruntime" + sycloud "github.com/calypr/syfon/client/cloud" + "github.com/spf13/cobra" +) + +// AddURLService groups injectable dependencies used to implement the add-url +// behavior (logger factory, object inspection, LFS helpers, config loader, etc.). +type AddURLService struct { + newLogger func(string, bool) (*slog.Logger, error) + inspectRemoteObject func(ctx context.Context, drsCtx *remoteruntime.GitContext, input addURLInput) (*inspectedObject, error) + getRemoteClient func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*remoteruntime.GitContext, error) + isLFSTracked func(path string) (bool, error) + getGitRoots func(ctx context.Context) (string, string, error) + gitLFSTrack func(ctx context.Context, path string) (bool, error) + loadConfig func() (*config.Config, error) +} + +// NewAddURLService constructs an AddURLService populated with production +// implementations of its dependencies. +func NewAddURLService() *AddURLService { + return &AddURLService{ + newLogger: drslog.NewLogger, + inspectRemoteObject: inspectRemoteObjectViaServer, + getRemoteClient: func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*remoteruntime.GitContext, error) { + return remoteruntime.New(cfg, remote, logger) + }, + isLFSTracked: lfs.IsLFSTracked, + getGitRoots: lfs.GetGitRootDirectories, + gitLFSTrack: gitrepo.TrackReadOnly, + loadConfig: config.LoadConfig, + } +} + +// Run executes the add-url workflow: parse CLI input, inspect the provider +// object through the configured Syfon remote, ensure the LFS object exists in +// local storage, write a pointer file, update the pre-commit cache +// (best-effort), optionally add a tracking entry, and record the DRS mapping. +func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + if ctx == nil { + ctx = context.Background() + } + + logger, err := s.newLogger("", false) + if err != nil { + return fmt.Errorf("error creating logger: %v", err) + } + + input, err := parseAddURLInput(cmd, args) + if err != nil { + return err + } + + cfg, err := s.loadConfig() + if err != nil { + return fmt.Errorf("error getting config: %v", err) + } + + remote, err := cfg.GetDefaultRemote() + if err != nil { + return err + } + + remoteConfig := cfg.GetRemote(remote) + if remoteConfig == nil { + return fmt.Errorf("error getting remote configuration for %s", remote) + } + + drsCtx, err := s.getRemoteClient(cfg, remote, logger) + if err != nil { + return err + } + + org, project, scope, err := resolveTargetScope(remoteConfig) + if err != nil { + return err + } + + if drsCtx != nil { + org = firstNonEmpty(strings.TrimSpace(drsCtx.Organization), org) + project = firstNonEmpty(strings.TrimSpace(drsCtx.ProjectId), project) + scope = gitrepo.ResolvedBucketScope{ + Bucket: firstNonEmpty(strings.TrimSpace(drsCtx.BucketName), scope.Bucket), + Prefix: firstNonEmpty(strings.TrimSpace(drsCtx.StoragePrefix), scope.Prefix), + } + } + + inspected, err := s.inspectRemoteObject(ctx, drsCtx, input) + if err != nil { + return err + } + input.objectURL = inspected.objectURL + objectInfo := inspected.info + + isTracked, err := s.isLFSTracked(input.path) + if err != nil { + return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) + } + + gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) + if err != nil { + return fmt.Errorf("get git root directories: %w", err) + } + + if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, objectInfo, input.path, isTracked, input.sha256); err != nil { + return err + } + + oid, err := s.ensureLFSObject(ctx, objectInfo, input, lfsRoot) + if err != nil { + return err + } + + if err := writePointerFile(input.path, oid, objectInfo.SizeBytes); err != nil { + return err + } + + if err := updatePrecommitCache(ctx, logger, input.path, oid, input.objectURL); err != nil { + logger.Warn("pre-commit cache update skipped", "error", err) + } + + if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { + return err + } + + builder := drsobjectBuilder(scope.Bucket, org, project, scope.Prefix) + file := addURLDrsFile{ + Name: input.path, + Size: objectInfo.SizeBytes, + Oid: oid, + } + if _, err := writeAddURLDrsObject(builder, file, input.objectURL); err != nil { + return fmt.Errorf("write local DRS object: %w", err) + } + + return nil +} + +func (s *AddURLService) ensureLFSObject(_ context.Context, objectInfo *sycloud.ObjectInfo, input addURLInput, _ string) (string, error) { + if input.sha256 != "" { + return input.sha256, nil + } + + return placeholderOIDForUnknownSHA(objectInfo.ETag, input.objectURL) +} diff --git a/cmd/addurl/scope_test.go b/cmd/addurl/scope_test.go index 9650bf5f..1a793a49 100644 --- a/cmd/addurl/scope_test.go +++ b/cmd/addurl/scope_test.go @@ -5,9 +5,6 @@ import ( "os/exec" "testing" - "log/slog" - - "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/gitrepo" ) @@ -29,9 +26,6 @@ func (f fakeRemote) GetBucketName() string { func (f fakeRemote) GetStoragePrefix() string { return f.prefix } -func (f fakeRemote) GetClient(string, *slog.Logger) (*config.GitContext, error) { - return nil, nil -} func TestResolveTargetScope_DefaultFallsBackToRemoteConfig(t *testing.T) { remote := fakeRemote{ diff --git a/cmd/addurl/service.go b/cmd/addurl/service.go deleted file mode 100644 index 79ad6195..00000000 --- a/cmd/addurl/service.go +++ /dev/null @@ -1,211 +0,0 @@ -package addurl - -import ( - "context" - "fmt" - "log/slog" - "os" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsobject" - "github.com/calypr/git-drs/internal/drstrack" - "github.com/calypr/git-drs/internal/lfs" - drsapi "github.com/calypr/syfon/apigen/client/drs" - sycloud "github.com/calypr/syfon/client/cloud" - "github.com/google/uuid" - "github.com/spf13/cobra" -) - -// AddURLService groups injectable dependencies used to implement the add-url -// behavior (logger factory, object inspection, LFS helpers, config loader, etc.). -type AddURLService struct { - newLogger func(string, bool) (*slog.Logger, error) - inspectObject func(ctx context.Context, input sycloud.ObjectParameters) (*sycloud.ObjectInfo, error) - isLFSTracked func(path string) (bool, error) - getGitRoots func(ctx context.Context) (string, string, error) - gitLFSTrack func(ctx context.Context, path string) (bool, error) - loadConfig func() (*config.Config, error) -} - -// NewAddURLService constructs an AddURLService populated with production -// implementations of its dependencies. -func NewAddURLService() *AddURLService { - return &AddURLService{ - newLogger: drslog.NewLogger, - inspectObject: sycloud.InspectObject, - isLFSTracked: lfs.IsLFSTracked, - getGitRoots: lfs.GetGitRootDirectories, - gitLFSTrack: drstrack.TrackReadOnly, - loadConfig: config.LoadConfig, - } -} - -// Run executes the add-url workflow: parse CLI input, resolve the target bucket -// scope, inspect the provider object through the client-owned cloud package, -// ensure the LFS object exists in local storage, write a pointer file, update -// the pre-commit cache (best-effort), optionally add a tracking entry, and -// record the DRS mapping. -func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - if ctx == nil { - ctx = context.Background() - } - - logger, err := s.newLogger("", false) - if err != nil { - return fmt.Errorf("error creating logger: %v", err) - } - - input, err := parseAddURLInput(cmd, args) - if err != nil { - return err - } - - cfg, err := s.loadConfig() - if err != nil { - return fmt.Errorf("error getting config: %v", err) - } - - remote, err := cfg.GetDefaultRemote() - if err != nil { - return err - } - - remoteConfig := cfg.GetRemote(remote) - if remoteConfig == nil { - return fmt.Errorf("error getting remote configuration for %s", remote) - } - - org, project, scope, err := resolveTargetScope(remoteConfig) - if err != nil { - return err - } - - input.objectURL, err = resolveObjectURL(input, scope) - if err != nil { - return err - } - - objectInfo, err := s.inspectObject(ctx, buildObjectParameters(input.objectURL, input.path, input.sha256)) - if err != nil { - return err - } - - isTracked, err := s.isLFSTracked(input.path) - if err != nil { - return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) - } - - gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) - if err != nil { - return fmt.Errorf("get git root directories: %w", err) - } - - if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, objectInfo, input.path, isTracked, input.sha256); err != nil { - return err - } - - oid, err := s.ensureLFSObject(ctx, objectInfo, input, lfsRoot) - if err != nil { - return err - } - - if err := writePointerFile(input.path, oid, objectInfo.SizeBytes); err != nil { - return err - } - - if err := updatePrecommitCache(ctx, logger, input.path, oid, input.objectURL); err != nil { - logger.Warn("pre-commit cache update skipped", "error", err) - } - - if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { - return err - } - - builder := drsobject.NewBuilder(scope.Bucket, project) - builder.Organization = org - builder.StoragePrefix = scope.Prefix - - file := addURLDrsFile{ - Name: input.path, - Size: objectInfo.SizeBytes, - Oid: oid, - } - if _, err := writeAddURLDrsObject(builder, file, input.objectURL); err != nil { - return fmt.Errorf("write local DRS object: %w", err) - } - - return nil -} - -type addURLDrsFile struct { - Name string - Size int64 - Oid string -} - -func writeAddURLDrsObject(builder drsobject.Builder, file addURLDrsFile, objectPath string) (*drsapi.DrsObject, error) { - existing, err := drsobject.ReadObject(common.DRS_OBJS_PATH, file.Oid) - var drsObj *drsapi.DrsObject - if err == nil && existing != nil { - drsObj = existing - name := file.Name - drsObj.Name = &name - drsObj.Size = file.Size - } else { - drsID := uuid.NewSHA1(drsobject.UUIDNamespace, []byte(fmt.Sprintf("%s:%s", builder.Project, drsobject.NormalizeOid(file.Oid)))).String() - drsObj, err = builder.Build(file.Name, file.Oid, file.Size, drsID) - if err != nil { - return nil, fmt.Errorf("error building DRS object for oid %s: %w", file.Oid, err) - } - } - - if objectPath != "" { - if drsObj.AccessMethods != nil && len(*drsObj.AccessMethods) > 0 { - am := &(*drsObj.AccessMethods)[0] - am.AccessUrl = &struct { - Headers *[]string `json:"headers,omitempty"` - Url string `json:"url"` - }{Url: objectPath} - } else { - drsObj.AccessMethods = &[]drsapi.AccessMethod{{ - Type: drsapi.AccessMethodTypeS3, - AccessUrl: &struct { - Headers *[]string `json:"headers,omitempty"` - Url string `json:"url"` - }{Url: objectPath}, - }} - } - } - - if err := drsobject.WriteObject(common.DRS_OBJS_PATH, drsObj, file.Oid); err != nil { - return nil, fmt.Errorf("error writing DRS object for oid %s: %w", file.Oid, err) - } - return drsObj, nil -} - -// ensureLFSObject ensures the LFS object identified by objectInfo exists in the -// repository's LFS storage. If SHA256 is provided, it is trusted and returned. -// Otherwise we create a sentinel object and synthetic OID derived from ETag, -// deferring true checksum validation to first real data use. -func (s *AddURLService) ensureLFSObject(ctx context.Context, objectInfo *sycloud.ObjectInfo, input addURLInput, lfsRoot string) (string, error) { - _ = ctx - if input.sha256 != "" { - return input.sha256, nil - } - - oid, err := lfs.SyntheticOIDFromETag(objectInfo.ETag) - if err != nil { - return "", err - } - objPath, err := lfs.WriteAddURLSentinelObject(lfsRoot, oid, objectInfo.ETag, input.objectURL) - if err != nil { - return "", err - } - if _, err := fmt.Fprintf(os.Stderr, "Added add-url sentinel object at %s\n", objPath); err != nil { - return "", fmt.Errorf("stderr write: %w", err) - } - return oid, nil -} diff --git a/cmd/bucket/main.go b/cmd/bucket/main.go index 745c67be..70da858c 100644 --- a/cmd/bucket/main.go +++ b/cmd/bucket/main.go @@ -11,11 +11,11 @@ import ( "strings" "time" - "github.com/calypr/data-client/credentials" - "github.com/calypr/git-drs/internal/common" + conf "github.com/calypr/calypr-cli/conf" + "github.com/calypr/calypr-cli/credentials" + "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" - conf "github.com/calypr/syfon/client/config" "github.com/spf13/cobra" ) @@ -249,8 +249,13 @@ func resolveEndpointAndToken(remoteName string) (string, string, error) { if prof, err := configure.Load(remoteName); err == nil { token = strings.TrimSpace(prof.AccessToken) if token == "" { - if ensureErr := credentials.EnsureValidCredential(context.Background(), prof, drslog.GetLogger()); ensureErr == nil { - _ = configure.Save(prof) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + ensureErr := credentials.EnsureValidCredential(ctx, prof, drslog.GetLogger()) + cancel() + if ensureErr == nil { + if err := configure.Save(prof); err != nil { + return "", "", fmt.Errorf("failed to save refreshed credential for remote %q: %w", remoteName, err) + } token = strings.TrimSpace(prof.AccessToken) } } @@ -270,7 +275,7 @@ func resolveEndpointAndToken(remoteName string) (string, string, error) { endpoint = strings.TrimSpace(endpoint) } if endpoint == "" { - parsed, err := common.ParseAPIEndpointFromToken(token) + parsed, err := config.ParseAPIEndpointFromToken(token) if err != nil { return "", "", fmt.Errorf("unable to resolve API endpoint from token: %w", err) } diff --git a/cmd/clean/main.go b/cmd/clean/main.go index 0f0c701a..fb895750 100644 --- a/cmd/clean/main.go +++ b/cmd/clean/main.go @@ -16,8 +16,8 @@ import ( "fmt" "os" - "github.com/calypr/git-drs/internal/drsfilter" "github.com/calypr/git-drs/internal/drslog" + internalfilter "github.com/calypr/git-drs/internal/filter" "github.com/calypr/git-drs/internal/lfs" "github.com/spf13/cobra" ) @@ -54,5 +54,5 @@ func runClean(cmd *cobra.Command, args []string) error { return fmt.Errorf("clean: resolve LFS root: %w", err) } - return drsfilter.CleanContent(ctx, lfsRoot, pathname, os.Stdin, os.Stdout, logger) + return internalfilter.CleanContent(ctx, lfsRoot, pathname, os.Stdin, os.Stdout, logger) } diff --git a/cmd/copyrecords/api.go b/cmd/copyrecords/api.go new file mode 100644 index 00000000..1280389c --- /dev/null +++ b/cmd/copyrecords/api.go @@ -0,0 +1,121 @@ +package copyrecords + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + drsapi "github.com/calypr/syfon/apigen/client/drs" + "github.com/calypr/syfon/client/request" + syservices "github.com/calypr/syfon/client/services" +) + +type copyHashInfo map[string]string + +type copyRecord struct { + AccessMethods *[]drsapi.AccessMethod `json:"access_methods,omitempty"` + ControlledAccess *[]string `json:"controlled_access,omitempty"` + CreatedTime *string `json:"created_time,omitempty"` + Description *string `json:"description,omitempty"` + Did string `json:"did"` + FileName *string `json:"file_name,omitempty"` + Name *string `json:"name,omitempty"` + Hashes *copyHashInfo `json:"hashes,omitempty"` + Organization *string `json:"organization,omitempty"` + Project *string `json:"project,omitempty"` + Size *int64 `json:"size,omitempty"` + UpdatedTime *string `json:"updated_time,omitempty"` + Version *string `json:"version,omitempty"` +} + +type copyListRecordsResponse struct { + Records *[]copyRecord `json:"records,omitempty"` +} + +type copyBulkCreateRequest struct { + Records []copyRecord `json:"records"` +} + +type copyBulkHashesRequest struct { + Hashes []string `json:"hashes"` +} + +type copyBulkHashesResponse struct { + Results map[string][]copyRecord `json:"results,omitempty"` +} + +type indexAPI interface { + List(ctx context.Context, opts syservices.ListRecordsOptions) (copyListRecordsResponse, error) + BulkDocuments(ctx context.Context, dids []string) ([]copyRecord, error) + BulkHashes(ctx context.Context, hashes []string) (copyBulkHashesResponse, error) + CreateBulk(ctx context.Context, req copyBulkCreateRequest) (copyListRecordsResponse, error) +} + +type rawIndexAPI struct { + requestor request.Requester +} + +func newRawIndexAPI(requestor request.Requester) *rawIndexAPI { + return &rawIndexAPI{requestor: requestor} +} + +func (r *rawIndexAPI) List(ctx context.Context, opts syservices.ListRecordsOptions) (copyListRecordsResponse, error) { + params := url.Values{} + if opts.Hash != "" { + params.Set("hash", opts.Hash) + } + if opts.URL != "" { + params.Set("url", opts.URL) + } + if opts.Organization != "" { + params.Set("organization", opts.Organization) + } + if opts.ProjectID != "" { + params.Set("project", opts.ProjectID) + } + if opts.Limit != 0 { + params.Set("limit", fmt.Sprintf("%d", opts.Limit)) + } + if opts.Page != 0 { + params.Set("page", fmt.Sprintf("%d", opts.Page)) + } + var out copyListRecordsResponse + if err := r.requestor.Do(ctx, http.MethodGet, "/index", nil, &out, request.WithQueryValues(params)); err != nil { + return copyListRecordsResponse{}, err + } + return out, nil +} + +func (r *rawIndexAPI) BulkDocuments(ctx context.Context, dids []string) ([]copyRecord, error) { + var out []copyRecord + if err := r.requestor.Do(ctx, http.MethodPost, "/index/bulk/documents", dids, &out); err != nil { + return nil, err + } + return out, nil +} + +func (r *rawIndexAPI) BulkHashes(ctx context.Context, hashes []string) (copyBulkHashesResponse, error) { + var out copyBulkHashesResponse + if err := r.requestor.Do(ctx, http.MethodPost, "/index/bulk/hashes", copyBulkHashesRequest{Hashes: hashes}, &out); err != nil { + return copyBulkHashesResponse{}, err + } + return out, nil +} + +func (r *rawIndexAPI) CreateBulk(ctx context.Context, req copyBulkCreateRequest) (copyListRecordsResponse, error) { + var out copyListRecordsResponse + if err := r.requestor.Do(ctx, http.MethodPost, "/index/bulk", req, &out); err != nil { + return copyListRecordsResponse{}, err + } + return out, nil +} + +func canonicalAccessMethod(method drsapi.AccessMethod) string { + b, err := json.Marshal(method) + if err != nil { + return fmt.Sprintf("%s|%v", method.Type, method.AccessId) + } + return string(b) +} diff --git a/cmd/copyrecords/main.go b/cmd/copyrecords/main.go new file mode 100644 index 00000000..4445d1df --- /dev/null +++ b/cmd/copyrecords/main.go @@ -0,0 +1,105 @@ +package copyrecords + +import ( + "fmt" + "strings" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/remoteruntime" + "github.com/spf13/cobra" +) + +var ( + batchSize int + overwriteNameFileName bool +) + +var Cmd = &cobra.Command{ + Use: "copy-records [source-remote] ", + Short: "Copy Syfon records between remotes for one organization/project scope", + Long: "Read all Syfon records for a source organization/project scope and bulk load them into a target Syfon instance, only merging controlled_access and access_methods for records that already exist on the target.", + Args: cobra.RangeArgs(2, 3), + RunE: func(cmd *cobra.Command, args []string) error { + logger := drslog.GetLogger() + cfg, err := config.LoadConfig() + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + + sourceRemote := "" + targetRemote := "" + scopeArg := "" + if len(args) == 2 { + targetRemote = args[0] + scopeArg = args[1] + } else { + sourceRemote = args[0] + targetRemote = args[1] + scopeArg = args[2] + } + + srcRemoteName, err := cfg.GetRemoteOrDefault(sourceRemote) + if err != nil { + return fmt.Errorf("error resolving source remote: %w", err) + } + if strings.TrimSpace(targetRemote) == "" { + return fmt.Errorf("target remote is required") + } + dstRemoteName := config.Remote(targetRemote) + if srcRemoteName == dstRemoteName { + return fmt.Errorf("source and target remotes must be different") + } + + srcCfg := cfg.GetRemote(srcRemoteName) + if srcCfg == nil { + return fmt.Errorf("source remote %q not found", srcRemoteName) + } + + org, proj, err := parseScopeArg(scopeArg) + if err != nil { + return err + } + + srcCtx, err := remoteruntime.New(cfg, srcRemoteName, logger) + if err != nil { + return fmt.Errorf("error creating source client: %w", err) + } + dstCtx, err := remoteruntime.New(cfg, dstRemoteName, logger) + if err != nil { + return fmt.Errorf("error creating target client: %w", err) + } + + stats, err := copyProjectRecords( + cmd.Context(), + logger, + newRawIndexAPI(srcCtx.Client.Requestor()), + newRawIndexAPI(dstCtx.Client.Requestor()), + org, + proj, + batchSize, + overwriteNameFileName, + ) + if err != nil { + return err + } + + logger.Info("copy-records complete", + "source_remote", srcRemoteName, + "target_remote", dstRemoteName, + "organization", org, + "project", proj, + "source_seen", stats.SourceSeen, + "created", stats.Created, + "updated", stats.Updated, + "unchanged", stats.Unchanged, + "written", stats.Written, + ) + return nil + }, +} + +func init() { + Cmd.Flags().IntVar(&batchSize, "batch-size", 250, "records per source page and target bulk write") + Cmd.Flags().BoolVar(&overwriteNameFileName, "overwrite-name-file-name", false, "for existing target records, replace target name and file_name with the source values") +} diff --git a/cmd/copyrecords/main_test.go b/cmd/copyrecords/main_test.go new file mode 100644 index 00000000..e5718900 --- /dev/null +++ b/cmd/copyrecords/main_test.go @@ -0,0 +1,331 @@ +package copyrecords + +import ( + "context" + "testing" + + drsapi "github.com/calypr/syfon/apigen/client/drs" + syservices "github.com/calypr/syfon/client/services" +) + +type fakeIndexAPI struct { + listResp copyListRecordsResponse + listFn func(opts syservices.ListRecordsOptions) copyListRecordsResponse + bulkDocsResp []copyRecord + bulkHashResp copyBulkHashesResponse + createBulkReq []copyBulkCreateRequest +} + +func (f *fakeIndexAPI) List(ctx context.Context, opts syservices.ListRecordsOptions) (copyListRecordsResponse, error) { + if f.listFn != nil { + return f.listFn(opts), nil + } + return f.listResp, nil +} + +func (f *fakeIndexAPI) BulkDocuments(ctx context.Context, dids []string) ([]copyRecord, error) { + return f.bulkDocsResp, nil +} + +func (f *fakeIndexAPI) BulkHashes(ctx context.Context, hashes []string) (copyBulkHashesResponse, error) { + return f.bulkHashResp, nil +} + +func (f *fakeIndexAPI) CreateBulk(ctx context.Context, req copyBulkCreateRequest) (copyListRecordsResponse, error) { + f.createBulkReq = append(f.createBulkReq, req) + return copyListRecordsResponse{Records: &req.Records}, nil +} + +func TestMergeExistingRecord_UnionsControlledAccessAndAccessMethodsOnly(t *testing.T) { + dstName := "target.bin" + srcName := "source.bin" + desc := "keep target description" + leftCA := []string{"/organization/A/project/P1"} + rightCA := []string{"/organization/A/project/P1", "/organization/A/project/P2"} + leftMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "s3://bucket/one"}, + }} + rightMethods := []drsapi.AccessMethod{ + leftMethods[0], + { + Type: drsapi.AccessMethodTypeHttps, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "https://example.org/two"}, + }, + } + + merged, changed := mergeExistingRecord( + copyRecord{ + Did: "did-1", + FileName: &dstName, + Description: &desc, + ControlledAccess: &leftCA, + AccessMethods: &leftMethods, + }, + copyRecord{ + Did: "did-1", + FileName: &srcName, + ControlledAccess: &rightCA, + AccessMethods: &rightMethods, + }, + false, + ) + + if !changed { + t.Fatalf("expected merge to report a change") + } + if merged.FileName == nil || *merged.FileName != dstName { + t.Fatalf("expected target metadata to be preserved, got %+v", merged.FileName) + } + if merged.Description == nil || *merged.Description != desc { + t.Fatalf("expected target description to be preserved") + } + if merged.ControlledAccess == nil || len(*merged.ControlledAccess) != 2 { + t.Fatalf("expected merged controlled access union, got %+v", merged.ControlledAccess) + } + if merged.AccessMethods == nil || len(*merged.AccessMethods) != 2 { + t.Fatalf("expected merged access method union, got %+v", merged.AccessMethods) + } +} + +func TestMergeExistingRecord_OverwritesFileNameWhenFlagEnabled(t *testing.T) { + dstName := "target.bin" + dstDisplayName := "target-display" + srcName := "nested/source.bin" + srcDisplayName := "source-display" + + merged, changed := mergeExistingRecord( + copyRecord{ + Did: "did-1", + FileName: &dstName, + Name: &dstDisplayName, + }, + copyRecord{ + Did: "did-1", + FileName: &srcName, + Name: &srcDisplayName, + }, + true, + ) + + if !changed { + t.Fatalf("expected merge to report a change") + } + if merged.FileName == nil || *merged.FileName != srcName { + t.Fatalf("expected source file_name to win, got %+v", merged.FileName) + } + if merged.Name == nil || *merged.Name != srcDisplayName { + t.Fatalf("expected source name to win, got %+v", merged.Name) + } +} + +func TestBuildMergedBatch_CreatesNewAndUpdatesExisting(t *testing.T) { + srcCA := []string{"/organization/A/project/P1"} + newCA := []string{"/organization/A/project/P2"} + existingHash := copyHashInfo{"sha256": "sha-existing"} + newHash := copyHashInfo{"sha256": "sha-new"} + srcMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "s3://bucket/a"}, + }} + newMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeHttps, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "https://example.org/b"}, + }} + + target := &fakeIndexAPI{ + bulkDocsResp: []copyRecord{ + { + Did: "did-existing", + Hashes: &existingHash, + ControlledAccess: &srcCA, + AccessMethods: &srcMethods, + }, + }, + bulkHashResp: copyBulkHashesResponse{ + Results: map[string][]copyRecord{ + "sha256:sha-existing": {{ + Did: "did-existing", + Hashes: &existingHash, + ControlledAccess: &srcCA, + AccessMethods: &srcMethods, + }}, + }, + }, + } + + source := []copyRecord{ + { + Did: "did-existing", + Hashes: &existingHash, + ControlledAccess: &newCA, + AccessMethods: &newMethods, + }, + { + Did: "did-new", + Hashes: &newHash, + ControlledAccess: &srcCA, + AccessMethods: &srcMethods, + }, + } + + out, stats, err := buildMergedBatch(context.Background(), target, source, false) + if err != nil { + t.Fatalf("buildMergedBatch error: %v", err) + } + if len(out) != 2 { + t.Fatalf("expected 2 output records, got %d", len(out)) + } + if stats.Created != 1 || stats.Updated != 1 || stats.Unchanged != 0 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestBuildMergedBatch_MergesIntoExistingChecksumSiblingWhenDIDDiffers(t *testing.T) { + srcHash := copyHashInfo{"sha256": "same-sha"} + dstCA := []string{"/organization/A/project/P1"} + srcCA := []string{"/organization/A/project/P2"} + dstMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "s3://bucket/existing"}, + }} + srcMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeHttps, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "https://example.org/copied"}, + }} + + target := &fakeIndexAPI{ + bulkDocsResp: nil, + bulkHashResp: copyBulkHashesResponse{ + Results: map[string][]copyRecord{ + "sha256:same-sha": {{ + Did: "did-target", + Hashes: &srcHash, + ControlledAccess: &dstCA, + AccessMethods: &dstMethods, + }}, + }, + }, + } + source := []copyRecord{{ + Did: "did-source", + Hashes: &srcHash, + ControlledAccess: &srcCA, + AccessMethods: &srcMethods, + }} + + out, stats, err := buildMergedBatch(context.Background(), target, source, false) + if err != nil { + t.Fatalf("buildMergedBatch error: %v", err) + } + if len(out) != 1 { + t.Fatalf("expected one merged output record, got %d", len(out)) + } + if out[0].Did != "did-target" { + t.Fatalf("expected checksum sibling DID to be preserved, got %q", out[0].Did) + } + if out[0].ControlledAccess == nil || len(*out[0].ControlledAccess) != 2 { + t.Fatalf("expected merged controlled access, got %+v", out[0].ControlledAccess) + } + if out[0].AccessMethods == nil || len(*out[0].AccessMethods) != 2 { + t.Fatalf("expected merged access methods, got %+v", out[0].AccessMethods) + } + if stats.Created != 0 || stats.Updated != 1 || stats.Unchanged != 0 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestCopyProjectRecords_CopiesAllControlledAccessRecordsWhenScopedListIsEmpty(t *testing.T) { + scopeCA := []string{"/organization/HTAN_INT/project/BForePC"} + source := &fakeIndexAPI{ + listFn: func(opts syservices.ListRecordsOptions) copyListRecordsResponse { + if opts.Organization == "HTAN_INT" && opts.ProjectID == "BForePC" { + return copyListRecordsResponse{Records: &[]copyRecord{}} + } + if opts.Organization == "" && opts.ProjectID == "" && opts.Page == 1 { + return copyListRecordsResponse{Records: &[]copyRecord{ + {Did: "did-in-scope", ControlledAccess: &scopeCA}, + {Did: "did-out-of-scope", ControlledAccess: &[]string{"/organization/OTHER/project/X"}}, + }} + } + return copyListRecordsResponse{Records: &[]copyRecord{}} + }, + } + target := &fakeIndexAPI{} + + stats, err := copyProjectRecords(context.Background(), nil, source, target, "HTAN_INT", "BForePC", 100, false) + if err != nil { + t.Fatalf("copyProjectRecords error: %v", err) + } + if stats.SourceSeen != 1 || stats.Created != 1 || stats.Written != 1 { + t.Fatalf("unexpected stats: %+v", stats) + } + if len(target.createBulkReq) != 1 || len(target.createBulkReq[0].Records) != 1 { + t.Fatalf("expected one created record, got %+v", target.createBulkReq) + } + if target.createBulkReq[0].Records[0].Did != "did-in-scope" { + t.Fatalf("unexpected copied did: %+v", target.createBulkReq[0].Records[0]) + } +} + +func TestCopyProjectRecords_IgnoresPartialScopedListAndCopiesAllControlledAccessRecords(t *testing.T) { + scopeCA := []string{"/organization/HTAN_INT/project/BForePC"} + otherCA := []string{"/organization/OTHER/project/X"} + source := &fakeIndexAPI{ + listFn: func(opts syservices.ListRecordsOptions) copyListRecordsResponse { + if opts.Organization == "HTAN_INT" && opts.ProjectID == "BForePC" { + return copyListRecordsResponse{Records: &[]copyRecord{ + {Did: "did-scoped-partial", ControlledAccess: &scopeCA}, + }} + } + if opts.Organization == "" && opts.ProjectID == "" && opts.Page == 1 { + return copyListRecordsResponse{Records: &[]copyRecord{ + {Did: "did-scoped-partial", ControlledAccess: &scopeCA}, + {Did: "did-missing-from-scoped-list", ControlledAccess: &scopeCA}, + {Did: "did-out-of-scope", ControlledAccess: &otherCA}, + }} + } + return copyListRecordsResponse{Records: &[]copyRecord{}} + }, + } + target := &fakeIndexAPI{} + + stats, err := copyProjectRecords(context.Background(), nil, source, target, "HTAN_INT", "BForePC", 100, false) + if err != nil { + t.Fatalf("copyProjectRecords error: %v", err) + } + if stats.SourceSeen != 2 || stats.Created != 2 || stats.Written != 2 { + t.Fatalf("unexpected stats: %+v", stats) + } + if len(target.createBulkReq) != 1 || len(target.createBulkReq[0].Records) != 2 { + t.Fatalf("expected two created records, got %+v", target.createBulkReq) + } + found := map[string]bool{} + for _, rec := range target.createBulkReq[0].Records { + found[rec.Did] = true + } + if !found["did-scoped-partial"] || !found["did-missing-from-scoped-list"] { + t.Fatalf("missing copied records: %+v", target.createBulkReq[0].Records) + } + if found["did-out-of-scope"] { + t.Fatalf("out-of-scope record should not be copied: %+v", target.createBulkReq[0].Records) + } +} diff --git a/cmd/copyrecords/merge.go b/cmd/copyrecords/merge.go new file mode 100644 index 00000000..85d824d9 --- /dev/null +++ b/cmd/copyrecords/merge.go @@ -0,0 +1,249 @@ +package copyrecords + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + drsapi "github.com/calypr/syfon/apigen/client/drs" +) + +func buildMergedBatch(ctx context.Context, dst indexAPI, source []copyRecord, overwriteNameFileName bool) ([]copyRecord, copyStats, error) { + stats := copyStats{} + if len(source) == 0 { + return nil, stats, nil + } + + dids := make([]string, 0, len(source)) + hashQueries := make([]string, 0, len(source)) + seenHashQueries := make(map[string]struct{}, len(source)) + for _, rec := range source { + did := strings.TrimSpace(rec.Did) + if did == "" { + continue + } + dids = append(dids, did) + if sha := copyRecordSHA256(rec); sha != "" { + query := "sha256:" + sha + if _, ok := seenHashQueries[query]; !ok { + seenHashQueries[query] = struct{}{} + hashQueries = append(hashQueries, query) + } + } + } + + existing, err := dst.BulkDocuments(ctx, dids) + if err != nil { + return nil, stats, fmt.Errorf("target bulk documents failed: %w", err) + } + existingByDID := make(map[string]copyRecord, len(existing)) + for _, rec := range existing { + existingByDID[strings.TrimSpace(rec.Did)] = rec + } + + existingByChecksum := make(map[string][]copyRecord) + if len(hashQueries) > 0 { + hashResp, err := dst.BulkHashes(ctx, hashQueries) + if err != nil { + return nil, stats, fmt.Errorf("target bulk hash lookup failed: %w", err) + } + for _, query := range hashQueries { + sha := strings.TrimSpace(strings.TrimPrefix(query, "sha256:")) + if sha == "" { + continue + } + existingByChecksum[sha] = dedupeCopyRecordsByDID(hashResp.Results[query]) + } + } + + created := make([]copyRecord, 0, len(source)) + pendingUpdates := make(map[string]copyRecord, len(source)) + updateOrder := make([]string, 0, len(source)) + for _, src := range source { + match, found, err := targetRecordForSource(src, existingByDID, existingByChecksum) + if err != nil { + return nil, stats, err + } + if !found { + created = append(created, src) + stats.Created++ + continue + } + + base := match + targetDID := strings.TrimSpace(match.Did) + if pending, ok := pendingUpdates[targetDID]; ok { + base = pending + } + merged, changed := mergeExistingRecord(base, src, overwriteNameFileName) + if changed { + if _, ok := pendingUpdates[targetDID]; !ok { + updateOrder = append(updateOrder, targetDID) + } + pendingUpdates[targetDID] = merged + stats.Updated++ + } else { + stats.Unchanged++ + } + } + + out := make([]copyRecord, 0, len(created)+len(pendingUpdates)) + for _, did := range updateOrder { + out = append(out, pendingUpdates[did]) + } + out = append(out, created...) + return out, stats, nil +} + +func targetRecordForSource(src copyRecord, existingByDID map[string]copyRecord, existingByChecksum map[string][]copyRecord) (copyRecord, bool, error) { + did := strings.TrimSpace(src.Did) + if did == "" { + return copyRecord{}, false, nil + } + if dstRec, ok := existingByDID[did]; ok { + return dstRec, true, nil + } + + sha := copyRecordSHA256(src) + if sha == "" { + return copyRecord{}, false, nil + } + matches := existingByChecksum[sha] + switch len(matches) { + case 0: + return copyRecord{}, false, nil + case 1: + return matches[0], true, nil + default: + dids := make([]string, 0, len(matches)) + for _, match := range matches { + if did := strings.TrimSpace(match.Did); did != "" { + dids = append(dids, did) + } + } + return copyRecord{}, false, fmt.Errorf("target already has multiple records for sha256 %q under different DIDs: %s", sha, strings.Join(dids, ", ")) + } +} + +func copyRecordSHA256(rec copyRecord) string { + if rec.Hashes == nil { + return "" + } + return strings.TrimSpace((*rec.Hashes)["sha256"]) +} + +func dedupeCopyRecordsByDID(records []copyRecord) []copyRecord { + if len(records) == 0 { + return nil + } + out := make([]copyRecord, 0, len(records)) + seen := make(map[string]struct{}, len(records)) + for _, rec := range records { + did := strings.TrimSpace(rec.Did) + if did == "" { + continue + } + if _, ok := seen[did]; ok { + continue + } + seen[did] = struct{}{} + out = append(out, rec) + } + return out +} + +func mergeExistingRecord(dst, src copyRecord, overwriteNameFileName bool) (copyRecord, bool) { + merged := dst + changed := false + + if overwriteNameFileName { + if !equalStringValuePointers(merged.Name, src.Name) { + merged.Name = src.Name + changed = true + } + if !equalStringValuePointers(merged.FileName, src.FileName) { + merged.FileName = src.FileName + changed = true + } + } + + controlledAccess := mergeStringLists(dst.ControlledAccess, src.ControlledAccess) + if !equalStringPointers(merged.ControlledAccess, controlledAccess) { + merged.ControlledAccess = controlledAccess + changed = true + } + + accessMethods := mergeAccessMethods(dst.AccessMethods, src.AccessMethods) + if !equalAccessMethodPointers(merged.AccessMethods, accessMethods) { + merged.AccessMethods = accessMethods + changed = true + } + + return merged, changed +} + +func mergeStringLists(left, right *[]string) *[]string { + seen := map[string]struct{}{} + out := make([]string, 0) + for _, list := range []*[]string{left, right} { + if list == nil { + continue + } + for _, raw := range *list { + val := strings.TrimSpace(raw) + if val == "" { + continue + } + if _, ok := seen[val]; ok { + continue + } + seen[val] = struct{}{} + out = append(out, val) + } + } + if len(out) == 0 { + return nil + } + return &out +} + +func mergeAccessMethods(left, right *[]drsapi.AccessMethod) *[]drsapi.AccessMethod { + seen := map[string]struct{}{} + out := make([]drsapi.AccessMethod, 0) + for _, list := range []*[]drsapi.AccessMethod{left, right} { + if list == nil { + continue + } + for _, method := range *list { + key := canonicalAccessMethod(method) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, method) + } + } + if len(out) == 0 { + return nil + } + return &out +} + +func equalStringPointers(a, b *[]string) bool { + return equalJSON(a, b) +} + +func equalStringValuePointers(a, b *string) bool { + return equalJSON(a, b) +} + +func equalAccessMethodPointers(a, b *[]drsapi.AccessMethod) bool { + return equalJSON(a, b) +} + +func equalJSON(a, b any) bool { + ab, _ := json.Marshal(a) + bb, _ := json.Marshal(b) + return string(ab) == string(bb) +} diff --git a/cmd/copyrecords/run.go b/cmd/copyrecords/run.go new file mode 100644 index 00000000..ddbd747c --- /dev/null +++ b/cmd/copyrecords/run.go @@ -0,0 +1,85 @@ +package copyrecords + +import ( + "context" + "fmt" + "log/slog" + "os" +) + +type copyStats struct { + SourceSeen int + Created int + Updated int + Unchanged int + Written int +} + +func copyProjectRecords(ctx context.Context, logger *slog.Logger, src indexAPI, dst indexAPI, org, project string, batchSize int, overwriteNameFileName bool) (copyStats, error) { + if batchSize <= 0 { + batchSize = 250 + } + + stats := copyStats{} + fmt.Fprintf(os.Stderr, "copy-records: scanning source records for %s/%s\n", org, project) + records, err := listSourceRecordsByControlledAccess(ctx, src, org, project, batchSize) + if err != nil { + return stats, err + } + stats.SourceSeen = len(records) + fmt.Fprintf(os.Stderr, "copy-records: source scan complete, %d records in scope\n", stats.SourceSeen) + + for start := 0; start < len(records); start += batchSize { + end := start + batchSize + if end > len(records) { + end = len(records) + } + + batch := records[start:end] + fmt.Fprintf(os.Stderr, "copy-records: reconciling batch %d-%d of %d\n", start+1, end, len(records)) + toWrite, batchStats, err := buildMergedBatch(ctx, dst, batch, overwriteNameFileName) + if err != nil { + return stats, err + } + stats.Created += batchStats.Created + stats.Updated += batchStats.Updated + stats.Unchanged += batchStats.Unchanged + + if len(toWrite) > 0 { + resp, err := dst.CreateBulk(ctx, copyBulkCreateRequest{Records: toWrite}) + if err != nil { + return stats, fmt.Errorf("target bulk create failed for batch starting at %d: %w", start, err) + } + if resp.Records != nil { + stats.Written += len(*resp.Records) + } else { + stats.Written += len(toWrite) + } + } + fmt.Fprintf( + os.Stderr, + "copy-records: batch %d-%d complete, created=%d updated=%d unchanged=%d written=%d\n", + start+1, + end, + batchStats.Created, + batchStats.Updated, + batchStats.Unchanged, + len(toWrite), + ) + + if logger != nil { + logger.Info("copy-records batch complete", + "organization", org, + "project", project, + "batch_start", start, + "source_records", len(batch), + "created", batchStats.Created, + "updated", batchStats.Updated, + "unchanged", batchStats.Unchanged, + "written", len(toWrite), + ) + } + } + + return stats, nil +} diff --git a/cmd/copyrecords/scan.go b/cmd/copyrecords/scan.go new file mode 100644 index 00000000..e04f13b5 --- /dev/null +++ b/cmd/copyrecords/scan.go @@ -0,0 +1,90 @@ +package copyrecords + +import ( + "context" + "fmt" + "os" + "strings" + + syservices "github.com/calypr/syfon/client/services" + sycommon "github.com/calypr/syfon/common" +) + +func parseScopeArg(raw string) (string, string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "", fmt.Errorf("scope is required and must be in organization/project form") + } + parts := strings.Split(raw, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + org := strings.TrimSpace(parts[0]) + project := strings.TrimSpace(parts[1]) + if org == "" || project == "" { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + return org, project, nil +} + +func listSourceRecordsByControlledAccess(ctx context.Context, src indexAPI, org, project string, batchSize int) ([]copyRecord, error) { + resource, err := sycommon.ResourcePath(org, project) + if err != nil { + return nil, fmt.Errorf("invalid scope %s/%s: %w", org, project, err) + } + if batchSize <= 0 { + batchSize = 250 + } + + page := 1 + out := make([]copyRecord, 0) + seen := map[string]struct{}{} + for { + fmt.Fprintf(os.Stderr, "copy-records: scanning source index page %d, matched-so-far=%d\n", page, len(out)) + listResp, err := src.List(ctx, syservices.ListRecordsOptions{ + Limit: batchSize, + Page: page, + }) + if err != nil { + return nil, fmt.Errorf("fallback source list failed for %s/%s page %d: %w", org, project, page, err) + } + records := []copyRecord{} + if listResp.Records != nil { + records = *listResp.Records + } + if len(records) == 0 { + break + } + for _, rec := range records { + if !recordHasControlledAccess(rec, resource) { + continue + } + did := strings.TrimSpace(rec.Did) + if did == "" { + continue + } + if _, ok := seen[did]; ok { + continue + } + seen[did] = struct{}{} + out = append(out, rec) + } + if len(records) < batchSize { + break + } + page++ + } + return out, nil +} + +func recordHasControlledAccess(rec copyRecord, resource string) bool { + if rec.ControlledAccess == nil { + return false + } + for _, candidate := range *rec.ControlledAccess { + if strings.TrimSpace(candidate) == resource { + return true + } + } + return false +} diff --git a/cmd/credentialhelper/main.go b/cmd/credentialhelper/main.go index f9324a64..e377c8f3 100644 --- a/cmd/credentialhelper/main.go +++ b/cmd/credentialhelper/main.go @@ -8,11 +8,12 @@ import ( "net/url" "os" "strings" + "time" - "github.com/calypr/data-client/credentials" + conf "github.com/calypr/calypr-cli/conf" + "github.com/calypr/calypr-cli/credentials" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" - conf "github.com/calypr/syfon/client/config" "github.com/spf13/cobra" ) @@ -54,8 +55,12 @@ var Cmd = &cobra.Command{ // Local/basic-auth remotes: prefer explicit repo credentials. if username, password, err := gitrepo.GetRemoteBasicAuth(remoteName); err == nil && username != "" && password != "" { - fmt.Fprintf(os.Stdout, "username=%s\npassword=%s\n\n", username, password) - _ = gitrepo.SetRemoteLFSURL(remoteName, endpoint) + if _, err := fmt.Fprintf(os.Stdout, "username=%s\npassword=%s\n\n", username, password); err != nil { + return fmt.Errorf("write credential helper response: %w", err) + } + if err := gitrepo.SetRemoteLFSURL(remoteName, endpoint); err != nil { + return fmt.Errorf("sync LFS URL for remote %q: %w", remoteName, err) + } return nil } @@ -72,12 +77,14 @@ var Cmd = &cobra.Command{ if token != "" { cred.AccessToken = token } - if ensureErr := credentials.EnsureValidCredential(context.Background(), cred, logg); ensureErr == nil { - _ = manager.Save(cred) - token = strings.TrimSpace(cred.AccessToken) - if token != "" { - _ = gitrepo.SetRemoteToken(remoteName, token) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + ensureErr := credentials.EnsureValidCredential(ctx, cred, logg) + cancel() + if ensureErr == nil { + if err := manager.Save(cred); err != nil { + return fmt.Errorf("save refreshed credential for remote %q: %w", remoteName, err) } + token = strings.TrimSpace(cred.AccessToken) } } @@ -86,10 +93,14 @@ var Cmd = &cobra.Command{ } // Username can be arbitrary for token-based Basic auth; server reads password token. - fmt.Fprintf(os.Stdout, "username=oauth2\npassword=%s\n\n", token) + if _, err := fmt.Fprintf(os.Stdout, "username=oauth2\npassword=%s\n\n", token); err != nil { + return fmt.Errorf("write credential helper response: %w", err) + } // Keep lfsurl synced for this remote. - _ = gitrepo.SetRemoteLFSURL(remoteName, endpoint) + if err := gitrepo.SetRemoteLFSURL(remoteName, endpoint); err != nil { + return fmt.Errorf("sync LFS URL for remote %q: %w", remoteName, err) + } return nil }, } diff --git a/cmd/delete/main.go b/cmd/delete/main.go index 9fe38764..9932f1eb 100644 --- a/cmd/delete/main.go +++ b/cmd/delete/main.go @@ -1,14 +1,17 @@ package delete import ( + "bufio" "context" "fmt" + "io" "os" + "strings" - "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" + "github.com/calypr/git-drs/internal/lookup" + "github.com/calypr/git-drs/internal/remoteruntime" "github.com/calypr/syfon/client/hash" "github.com/spf13/cobra" ) @@ -18,6 +21,8 @@ var ( confirmFlag bool ) +const confirmYes = "yes" + // Cmd line declaration // Cmd line declaration var Cmd = &cobra.Command{ @@ -46,14 +51,14 @@ var Cmd = &cobra.Command{ return fmt.Errorf("error getting default remote: %v", err) } - drsClient, err := cfg.GetRemoteClient(remoteName, logger) + drsClient, err := remoteruntime.New(cfg, remoteName, logger) if err != nil { logger.Error(fmt.Sprintf("error creating DRS client: %s", err)) return err } // Get record details before deletion for confirmation - records, err := drsremote.ObjectsByHashForScope(context.Background(), drsClient, oid) + records, err := lookup.ObjectsByHashForScope(context.Background(), drsClient, oid) if err != nil { return fmt.Errorf("error getting records for OID %s: %v", oid, err) } @@ -64,22 +69,40 @@ var Cmd = &cobra.Command{ // Show details and get confirmation unless --confirm flag is set if !confirmFlag { projectId := drsClient.ProjectId - common.DisplayWarningHeader(os.Stderr, "DELETE a DRS record") - common.DisplayField(os.Stderr, "Remote", string(remoteName)) - common.DisplayField(os.Stderr, "Project", projectId) - common.DisplayField(os.Stderr, "OID", oid) - common.DisplayField(os.Stderr, "Hash Type", hashType) - common.DisplayField(os.Stderr, "Matched DIDs", fmt.Sprintf("%d", len(records))) + if err := displayWarningHeader(os.Stderr, "DELETE a DRS record"); err != nil { + return err + } + if err := displayField(os.Stderr, "Remote", string(remoteName)); err != nil { + return err + } + if err := displayField(os.Stderr, "Project", projectId); err != nil { + return err + } + if err := displayField(os.Stderr, "OID", oid); err != nil { + return err + } + if err := displayField(os.Stderr, "Hash Type", hashType); err != nil { + return err + } + if err := displayField(os.Stderr, "Matched DIDs", fmt.Sprintf("%d", len(records))); err != nil { + return err + } if len(records) > 0 { - common.DisplayField(os.Stderr, "Example DID", records[0].Id) + if err := displayField(os.Stderr, "Example DID", records[0].Id); err != nil { + return err + } + } + if err := displayField(os.Stderr, "Warning", "This deletes all DIDs (pointers) resolved by this SHA256 in this backend"); err != nil { + return err + } + if err := displayFooter(os.Stderr); err != nil { + return err } - common.DisplayField(os.Stderr, "Warning", "This deletes all DIDs (pointers) resolved by this SHA256 in this backend") - common.DisplayFooter(os.Stderr) - if err := common.PromptForConfirmation( + if err := promptForConfirmation( os.Stderr, "Type 'yes' to confirm deletion", - common.ConfirmationYes, + confirmYes, false, ); err != nil { return err @@ -101,3 +124,42 @@ func init() { Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") Cmd.Flags().BoolVar(&confirmFlag, "confirm", false, "skip interactive confirmation prompt") } + +func promptForConfirmation(w io.Writer, prompt string, expectedResponse string, caseSensitive bool) error { + if _, err := fmt.Fprintf(w, "%s: ", prompt); err != nil { + return err + } + + reader := bufio.NewReader(os.Stdin) + response, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("error reading confirmation: %v", err) + } + + response = strings.TrimSpace(response) + if !caseSensitive { + response = strings.ToLower(response) + expectedResponse = strings.ToLower(expectedResponse) + } + + if response != expectedResponse { + return fmt.Errorf("operation cancelled: confirmation did not match") + } + + return nil +} + +func displayWarningHeader(w io.Writer, operation string) error { + _, err := fmt.Fprintf(w, "\nWARNING: You are about to %s\n\n", operation) + return err +} + +func displayField(w io.Writer, key, value string) error { + _, err := fmt.Fprintf(w, "%-11s %s\n", key+":", value) + return err +} + +func displayFooter(w io.Writer) error { + _, err := fmt.Fprintf(w, "\nThis action CANNOT be undone.\n\n") + return err +} diff --git a/cmd/deleteproject/main.go b/cmd/deleteproject/main.go index b35606ff..1a43675f 100644 --- a/cmd/deleteproject/main.go +++ b/cmd/deleteproject/main.go @@ -1,13 +1,16 @@ package deleteproject import ( + "bufio" "context" "fmt" + "io" "os" + "strings" - "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/remoteruntime" syservices "github.com/calypr/syfon/client/services" "github.com/spf13/cobra" ) @@ -38,7 +41,7 @@ var Cmd = &cobra.Command{ return fmt.Errorf("error getting default remote: %v", err) } - drsClient, err := cfg.GetRemoteClient(remoteName, logger) + drsClient, err := remoteruntime.New(cfg, remoteName, logger) if err != nil { logger.Error(fmt.Sprintf("error creating DRS client: %s", err)) return err @@ -66,28 +69,42 @@ var Cmd = &cobra.Command{ return fmt.Errorf("error: --confirm value '%s' does not match project ID '%s'", confirmFlag, projectId) } if confirmFlag != projectId { - common.DisplayWarningHeader(os.Stderr, "DELETE ALL RECORDS for a project") - common.DisplayField(os.Stderr, "Remote", string(remoteName)) - common.DisplayField(os.Stderr, "Project ID", projectId) + if err := displayWarningHeader(os.Stderr, "DELETE ALL RECORDS for a project"); err != nil { + return err + } + if err := displayField(os.Stderr, "Remote", string(remoteName)); err != nil { + return err + } + if err := displayField(os.Stderr, "Project ID", projectId); err != nil { + return err + } if listResp.Records != nil && len(*listResp.Records) > 0 { sample := (*listResp.Records)[0] fmt.Fprintf(os.Stderr, "\nSample record from this project:\n") - common.DisplayField(os.Stderr, " DID", sample.Did) + if err := displayField(os.Stderr, " DID", sample.Did); err != nil { + return err + } if sample.FileName != nil && *sample.FileName != "" { - common.DisplayField(os.Stderr, " Filename", *sample.FileName) + if err := displayField(os.Stderr, " Filename", *sample.FileName); err != nil { + return err + } } if sample.Size != nil { - common.DisplayField(os.Stderr, " Size", fmt.Sprintf("%d bytes", *sample.Size)) + if err := displayField(os.Stderr, " Size", fmt.Sprintf("%d bytes", *sample.Size)); err != nil { + return err + } } } else { fmt.Fprintf(os.Stderr, "\nNo records found for this project.\n") } fmt.Fprintf(os.Stderr, "\nThis will DELETE ALL records in project '%s'.\n", projectId) - common.DisplayFooter(os.Stderr) + if err := displayFooter(os.Stderr); err != nil { + return err + } - if err := common.PromptForConfirmation(os.Stderr, fmt.Sprintf("Type the project ID '%s' to confirm deletion", projectId), projectId, true); err != nil { + if err := promptForConfirmation(os.Stderr, fmt.Sprintf("Type the project ID '%s' to confirm deletion", projectId), projectId, true); err != nil { return err } } @@ -110,3 +127,42 @@ func init() { Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") Cmd.Flags().StringVar(&confirmFlag, "confirm", "", "skip interactive confirmation by providing the project_id (e.g., --confirm my-project)") } + +func promptForConfirmation(w io.Writer, prompt string, expectedResponse string, caseSensitive bool) error { + if _, err := fmt.Fprintf(w, "%s: ", prompt); err != nil { + return err + } + + reader := bufio.NewReader(os.Stdin) + response, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("error reading confirmation: %v", err) + } + + response = strings.TrimSpace(response) + if !caseSensitive { + response = strings.ToLower(response) + expectedResponse = strings.ToLower(expectedResponse) + } + + if response != expectedResponse { + return fmt.Errorf("operation cancelled: confirmation did not match") + } + + return nil +} + +func displayWarningHeader(w io.Writer, operation string) error { + _, err := fmt.Fprintf(w, "\nWARNING: You are about to %s\n\n", operation) + return err +} + +func displayField(w io.Writer, key, value string) error { + _, err := fmt.Fprintf(w, "%-11s %s\n", key+":", value) + return err +} + +func displayFooter(w io.Writer) error { + _, err := fmt.Fprintf(w, "\nThis action CANNOT be undone.\n\n") + return err +} diff --git a/cmd/download/main.go b/cmd/download/main.go deleted file mode 100644 index 599cd2e0..00000000 --- a/cmd/download/main.go +++ /dev/null @@ -1,100 +0,0 @@ -package download - -import ( - "context" - "fmt" - "path/filepath" - "strings" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" - drsapi "github.com/calypr/syfon/apigen/client/drs" - sydownload "github.com/calypr/syfon/client/transfer/download" - "github.com/spf13/cobra" -) - -var remote string -var outdir string - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "download ", - Short: "Download a file from a DRS server", - Long: "Download a file from a DRS server, without creating an LFS pointer", - Args: cobra.MinimumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - - logger := drslog.GetLogger() - - config, err := config.LoadConfig() - if err != nil { - return err - } - - remoteName, err := config.GetRemoteOrDefault(remote) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - - client, err := config.GetRemoteClient(remoteName, logger) - if err != nil { - return err - } - for _, src := range args { - obj, err := client.Client.DRS().GetObject(context.Background(), src) - if err != nil { - logger.Error(fmt.Sprintf("Error downloading object %s: %v", src, err)) - } else { - common.PrintDRSObject(obj, false) - dstName := src - if obj.Name != nil && *obj.Name != "" { - dstName = filepath.Base(*obj.Name) - } - dstPath := filepath.Join(outdir, dstName) - logger.Info(fmt.Sprintf("Downloading object %s to path %s", src, dstPath)) - accessURL, err := resolveAccessURL(cmd.Context(), client, obj) - if err != nil { - logger.Error(fmt.Sprintf("Error resolving access URL for object %s: %v", src, err)) - continue - } - if err := drsremote.DownloadResolvedToPath(cmd.Context(), client, obj.Id, dstPath, &obj, accessURL, sydownload.DownloadOptions{ - MultipartThreshold: 5 * 1024 * 1024, - Concurrency: 2, - ChunkSize: 64 * 1024 * 1024, - }); err != nil { - logger.Error(fmt.Sprintf("Error downloading object %s to path %s: %v", src, dstPath, err)) - } else { - logger.Info(fmt.Sprintf("Successfully downloaded object %s to path %s", src, dstPath)) - } - } - } - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") - Cmd.Flags().StringVarP(&outdir, "outdir", "o", ".", "output directory for downloaded files") -} - -func resolveAccessURL(ctx context.Context, remote *config.GitContext, obj drsapi.DrsObject) (*drsapi.AccessURL, error) { - if remote == nil || remote.Client == nil { - return nil, fmt.Errorf("DRS client unavailable") - } - if obj.AccessMethods == nil || len(*obj.AccessMethods) == 0 { - return nil, fmt.Errorf("no access methods available for DRS object %s", obj.Id) - } - accessType := strings.TrimSpace(string((*obj.AccessMethods)[0].Type)) - if accessType == "" { - return nil, fmt.Errorf("no access type found in access method for DRS object %s", obj.Id) - } - accessURL, err := remote.Client.DRS().GetAccessURL(ctx, obj.Id, accessType) - if err != nil { - return nil, err - } - return &accessURL, nil -} diff --git a/cmd/fetch/fetch_test.go b/cmd/fetch/fetch_test.go deleted file mode 100644 index 37766718..00000000 --- a/cmd/fetch/fetch_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package fetch - -import ( - "testing" - - "github.com/calypr/git-drs/internal/testutils" - "github.com/stretchr/testify/assert" -) - -func TestFetchCmdArgs(t *testing.T) { - // Test with no arguments (valid) - err := Cmd.Args(Cmd, []string{}) - assert.NoError(t, err) - - // Test with 1 argument (valid) - err = Cmd.Args(Cmd, []string{"origin"}) - assert.NoError(t, err) - - // Test with multiple arguments (invalid) - err = Cmd.Args(Cmd, []string{"origin", "extra"}) - assert.Error(t, err) -} - -func TestFetchRun_Error(t *testing.T) { - _ = testutils.SetupTestGitRepo(t) - // No config, should error - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) -} - -func TestFetchRun_InvalidRemote(t *testing.T) { - tmpDir := testutils.SetupTestGitRepo(t) - testutils.CreateDefaultTestConfig(t, tmpDir) - // Fetch from non-existent remote - err := Cmd.RunE(Cmd, []string{"no-remote"}) - assert.Error(t, err) -} diff --git a/cmd/fetch/main.go b/cmd/fetch/main.go deleted file mode 100644 index 0acf089a..00000000 --- a/cmd/fetch/main.go +++ /dev/null @@ -1,66 +0,0 @@ -package fetch - -import ( - "fmt" - "os/exec" - "strings" - - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/spf13/cobra" -) - -var runCommand = func(name string, args ...string) ([]byte, error) { - cmd := exec.Command(name, args...) - return cmd.CombinedOutput() -} - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "fetch [remote-name]", - Short: "Fetch LFS objects from remote via standard git-lfs", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) > 1 { - cmd.SilenceUsage = false - return fmt.Errorf("error: accepts at most 1 argument (remote name), received %d\n\nUsage: %s\n\nSee 'git drs fetch --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - logger := drslog.GetLogger() - - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) - } - - var remote config.Remote - if len(args) > 0 { - remote = config.Remote(args[0]) - } else { - remote, err = cfg.GetDefaultRemote() - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - } - - drsClient, err := cfg.GetRemoteClient(remote, logger) - if err != nil { - logger.Error(fmt.Sprintf("\nerror creating DRS client: %s", err)) - return err - } - _ = drsClient // Remote validation only. - - out, err := runCommand("git", "lfs", "pull", string(remote)) - if err != nil { - msg := strings.TrimSpace(string(out)) - if msg == "" { - msg = err.Error() - } - return fmt.Errorf("git lfs pull failed for remote %q: %s", remote, msg) - } - - return nil - }, -} diff --git a/cmd/filter/main.go b/cmd/filter/main.go index 7f5ebd04..cc70f45b 100644 --- a/cmd/filter/main.go +++ b/cmd/filter/main.go @@ -19,11 +19,11 @@ import ( "os" "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drsfilter" "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" - "github.com/calypr/git-drs/internal/gitfilter" + internalfilter "github.com/calypr/git-drs/internal/filter" "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/remoteruntime" + internaltransfer "github.com/calypr/git-drs/internal/transfer" "github.com/spf13/cobra" ) @@ -51,13 +51,13 @@ func runFilter(cmd *cobra.Command, _ []string) error { return fmt.Errorf("filter: load config: %w", err) } - var drsCtx *config.GitContext + var drsCtx *remoteruntime.GitContext remote, err := cfg.GetDefaultRemote() if err != nil { logger.Info("filter: no default remote", "err", err) } else { - drsCtx, err = cfg.GetRemoteClient(remote, logger) + drsCtx, err = remoteruntime.New(cfg, remote, logger) if err != nil { logger.Info("DRS server not configured or unreachable", "err", err) } @@ -69,7 +69,7 @@ func runFilter(cmd *cobra.Command, _ []string) error { } logger.Debug("Resolved LFS root directory", "lfsRoot", lfsRoot) // Build the filter and register handlers. - f := gitfilter.NewGitFilter(os.Stdin, os.Stdout, logger). + f := internalfilter.NewGitFilter(os.Stdin, os.Stdout, logger). OnSmudge(makeSmudgeHandler(drsCtx, logger)). OnClean(makeCleanHandler(lfsRoot, logger)) @@ -80,16 +80,16 @@ func runFilter(cmd *cobra.Command, _ []string) error { // Smudge handler — checkout: LFS pointer → real file content // -------------------------------------------------------------------------- -func makeSmudgeHandler(drsCtx *config.GitContext, logger *slog.Logger) gitfilter.SmudgeFunc { - return func(ctx context.Context, req gitfilter.FilterRequest, ptr io.Reader, dst io.Writer) error { +func makeSmudgeHandler(drsCtx *remoteruntime.GitContext, logger *slog.Logger) internalfilter.SmudgeFunc { + return func(ctx context.Context, req internalfilter.FilterRequest, ptr io.Reader, dst io.Writer) error { logger.Debug("smudge handler invoked", "pathname", req.Pathname) - var downloadFn drsfilter.SmudgeDownloadFunc - if drsCtx != nil { + var downloadFn internalfilter.SmudgeDownloadFunc + if drsCtx != nil && !internalfilter.ShouldSkipSmudge() { downloadFn = func(callCtx context.Context, oid, cachePath string) error { - return drsremote.DownloadToCachePath(callCtx, drsCtx, logger, oid, cachePath) + return internaltransfer.DownloadToCachePath(callCtx, drsCtx, oid, cachePath) } } - return drsfilter.SmudgeContent(ctx, req.Pathname, ptr, dst, logger, downloadFn) + return internalfilter.SmudgeContent(ctx, req.Pathname, ptr, dst, logger, downloadFn) } } @@ -97,10 +97,10 @@ func makeSmudgeHandler(drsCtx *config.GitContext, logger *slog.Logger) gitfilter // Clean handler — stage: real file content → LFS pointer // -------------------------------------------------------------------------- -func makeCleanHandler(lfsRoot string, logger *slog.Logger) gitfilter.CleanFunc { - return func(ctx context.Context, req gitfilter.FilterRequest, content io.Reader, dst io.Writer) error { +func makeCleanHandler(lfsRoot string, logger *slog.Logger) internalfilter.CleanFunc { + return func(ctx context.Context, req internalfilter.FilterRequest, content io.Reader, dst io.Writer) error { logger.Debug("clean", "pathname", req.Pathname) - return drsfilter.CleanContent(ctx, lfsRoot, req.Pathname, content, dst, logger) + return internalfilter.CleanContent(ctx, lfsRoot, req.Pathname, content, dst, logger) } } diff --git a/cmd/initialize/main.go b/cmd/initialize/main.go index 95fcdf65..304609c2 100644 --- a/cmd/initialize/main.go +++ b/cmd/initialize/main.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" @@ -39,73 +38,152 @@ var Cmd = &cobra.Command{ }, RunE: func(cmd *cobra.Command, args []string) error { logg := drslog.GetLogger() - - // check if .git dir exists to ensure you're in a git repository - _, err := gitrepo.GitTopLevel() - if err != nil { - return fmt.Errorf("error: not in a git repository. Please run this command in the root of your git repository") + if err := InitializeRepo(logg); err != nil { + return err } + logg.Debug(fmt.Sprintf("Using %d concurrent transfers", transfers)) + return nil + }, +} - // create config file if it doesn't exist - err = config.CreateEmptyConfig() - if err != nil { - return fmt.Errorf("error: unable to create config file: %v", err) - } +// InitializeRepo applies git-drs repository-local setup to the current git repository. +// It is safe to call repeatedly. +func InitializeRepo(logg *slog.Logger) error { + // check if .git dir exists to ensure you're in a git repository + _, err := gitrepo.GitTopLevel() + if err != nil { + return fmt.Errorf("error: not in a git repository. Please run this command in the root of your git repository") + } - // load the config - _, err = config.LoadConfig() - if err != nil { - logg.Debug(fmt.Sprintf("We should probably fix this: %v", err)) - return fmt.Errorf("error: unable to load config file: %v", err) - } + // create config file if it doesn't exist + err = config.CreateEmptyConfig() + if err != nil { + return fmt.Errorf("error: unable to create config file: %v", err) + } - // create drs directories - drsDir := common.DRS_DIR - drsLfsObjsDir := common.DRS_OBJS_PATH - if err := os.MkdirAll(drsDir, 0755); err != nil { - return fmt.Errorf("error: unable to create drs directory: %v", err) - } - if err := os.MkdirAll(drsLfsObjsDir, 0755); err != nil { - return fmt.Errorf("error: unable to create drs lfs objects directory: %v", err) - } + // load the config + _, err = config.LoadConfig() + if err != nil { + logg.Debug(fmt.Sprintf("We should probably fix this: %v", err)) + return fmt.Errorf("error: unable to load config file: %v", err) + } - err = initGitConfig() - if err != nil { - return fmt.Errorf("error initializing git-drs repository config: %v", err) - } + // create drs directories + drsDir := gitrepo.DRSDir + drsLfsObjsDir := gitrepo.DRSObjectsPath + if err := os.MkdirAll(drsDir, 0755); err != nil { + return fmt.Errorf("error: unable to create drs directory: %v", err) + } + if err := os.MkdirAll(drsLfsObjsDir, 0755); err != nil { + return fmt.Errorf("error: unable to create drs lfs objects directory: %v", err) + } - // install pre-push hook - err = installPrePushHook(logg) - if err != nil { - return fmt.Errorf("error installing pre-push hook: %v", err) - } - // install pre-commit hook - err = installPreCommitHook(logg) - if err != nil { - return fmt.Errorf("error installing pre-commit hook: %v", err) + err = initGitConfig() + if err != nil { + return fmt.Errorf("error initializing git-drs repository config: %v", err) + } + + // install pre-commit hook + err = installPreCommitHook(logg) + if err != nil { + return fmt.Errorf("error installing pre-commit hook: %v", err) + } + if err := removeLegacyPrePushHook(logg); err != nil { + return fmt.Errorf("error repairing legacy pre-push hook: %v", err) + } + + logg.Debug("Git DRS initialized") + return nil +} + +// EnsureInitialized applies initialization only when the repository does not +// already appear to have git-drs local setup installed. +func EnsureInitialized(logg *slog.Logger) error { + initialized, err := isInitialized() + if err != nil { + return err + } + if initialized { + return removeLegacyPrePushHook(logg) + } + return InitializeRepo(logg) +} + +func isInitialized() (bool, error) { + if _, err := gitrepo.GitTopLevel(); err != nil { + return false, fmt.Errorf("error: not in a git repository. Please run this command in the root of your git repository") + } + + if _, err := os.Stat(gitrepo.DRSDir); err != nil { + if os.IsNotExist(err) { + return false, nil } + return false, fmt.Errorf("error checking git-drs directory: %v", err) + } - // final logs - logg.Debug("Git DRS initialized") - logg.Debug(fmt.Sprintf("Using %d concurrent transfers", transfers)) - return nil - }, + if val, err := gitrepo.GetGitConfigString("filter.drs.process"); err != nil || strings.TrimSpace(val) != "git-drs filter" { + return false, err + } + if val, err := gitrepo.GetGitConfigString("filter.drs.clean"); err != nil || strings.TrimSpace(val) != "git-drs clean -- %f" { + return false, err + } + if val, err := gitrepo.GetGitConfigString("filter.drs.smudge"); err != nil || strings.TrimSpace(val) != "git-drs smudge -- %f" { + return false, err + } + if val, err := gitrepo.GetGitConfigString("filter.drs.required"); err != nil || strings.TrimSpace(val) != "true" { + return false, err + } + + preCommitInstalled, err := hookContains("pre-commit", "git drs precommit") + if err != nil { + return false, err + } + if !preCommitInstalled { + return false, nil + } + + return true, nil +} + +func hookContains(name, marker string) (bool, error) { + hooksDir, err := gitrepo.GetGitHooksDir() + if err != nil { + return false, fmt.Errorf("unable to get hooks directory: %w", err) + } + content, err := os.ReadFile(filepath.Join(hooksDir, name)) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + return strings.Contains(string(content), marker), nil } +var noSkipSmudge bool + func initGitConfig() error { configs := map[string]string{ + "push.autoSetupRemote": "true", "lfs.allowincompletepush": "false", "lfs.concurrenttransfers": strconv.Itoa(transfers), // Use git-drs as the long-running filter-process handler. // This replaces the default git-lfs smudge/clean per-invocation commands // with a single persistent process that calls the DRS transfer stack directly. - "filter.drs.process": "git-drs filter", + "filter.drs.clean": "git-drs clean -- %f", + "filter.drs.smudge": "git-drs smudge -- %f", + "filter.drs.process": "git-drs filter", + "filter.drs.required": "true", // Canonical git-drs config keys consumed by clients. "drs.upsert": strconv.FormatBool(upsert), "drs.multipart-threshold": strconv.Itoa(multiPartThreshold), "drs.enable-data-client-logs": strconv.FormatBool(enableDataClientLogs), } + if noSkipSmudge { + configs["drs.skipsmudge"] = "false" + } + if err := gitrepo.SetGitConfigOptions(configs); err != nil { return fmt.Errorf("unable to write git config: %w", err) } @@ -117,62 +195,7 @@ func init() { Cmd.Flags().BoolVarP(&upsert, "upsert", "u", false, "Enable upsert for DRS objects") Cmd.Flags().IntVarP(&multiPartThreshold, "multipart-threshold", "m", 5120, "Multipart threshold in MB") Cmd.Flags().BoolVar(&enableDataClientLogs, "enable-data-client-logs", false, "Enable data-client internal logs") -} - -func installPrePushHook(logger *slog.Logger) error { - hooksDir, err := gitrepo.GetGitHooksDir() - if err != nil { - return fmt.Errorf("unable to get hooks directory: %w", err) - } - - if err := os.MkdirAll(hooksDir, 0755); err != nil { - return fmt.Errorf("unable to create hooks directory: %w", err) - } - - hookPath := filepath.Join(hooksDir, "pre-push") - hookBody := ` -# . git/hooks/pre-push -remote="$1" -url="$2" - -# Buffer stdin for both commands -TMPFILE="${TMPDIR:-/tmp}/git-drs-$$" -trap "rm -f $TMPFILE" EXIT -cat > "$TMPFILE" - -# Run DRS preparation -git drs pre-push-prepare "$remote" "$url" < "$TMPFILE" || exit 1 - -# The managed git-drs push command handles upload/register directly. -# The hook only stages metadata before the Git push proceeds. -` - hookScript := "#!/bin/sh\n" + hookBody - - existingContent, err := os.ReadFile(hookPath) - if err == nil { - // there is an existing hook, rename it, and let the user know - // Backup existing hook with timestamp - timestamp := time.Now().Format("20060102T150405") - backupPath := hookPath + "." + timestamp - if err := os.WriteFile(backupPath, existingContent, 0644); err != nil { - return fmt.Errorf("unable to back up existing pre-push hook: %w", err) - } - if err := os.Remove(hookPath); err != nil { - return fmt.Errorf("unable to remove hook after backing up: %w", err) - } - logger.Debug(fmt.Sprintf("pre-push hook updated; backup written to %s", backupPath)) - } - // If there was an error other than expected not existing, return it - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("unable to read pre-push hook: %w", err) - } - - err = os.WriteFile(hookPath, []byte(hookScript), 0755) - if err != nil { - return fmt.Errorf("unable to write pre-push hook: %w", err) - } - logger.Debug("pre-push hook installed") - return nil + Cmd.Flags().BoolVar(&noSkipSmudge, "no-skip-smudge", false, "Disable skipping smudge filter (force downloading file contents during checkout)") } func installPreCommitHook(logger *slog.Logger) error { @@ -220,3 +243,32 @@ exec git drs precommit logger.Debug("pre-commit hook installed") return nil } + +func removeLegacyPrePushHook(logger *slog.Logger) error { + hooksDir, err := gitrepo.GetGitHooksDir() + if err != nil { + return fmt.Errorf("unable to get hooks directory: %w", err) + } + hookPath := filepath.Join(hooksDir, "pre-push") + content, err := os.ReadFile(hookPath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("unable to read pre-push hook: %w", err) + } + if !strings.Contains(string(content), "git drs pre-push-prepare") { + return nil + } + + timestamp := time.Now().Format("20060102T150405") + backupPath := hookPath + "." + timestamp + if err := os.WriteFile(backupPath, content, 0o644); err != nil { + return fmt.Errorf("unable to back up legacy pre-push hook: %w", err) + } + if err := os.Remove(hookPath); err != nil { + return fmt.Errorf("unable to remove legacy pre-push hook: %w", err) + } + logger.Debug(fmt.Sprintf("legacy pre-push hook removed; backup written to %s", backupPath)) + return nil +} diff --git a/cmd/initialize/main_test.go b/cmd/initialize/main_test.go index 1126a2dd..65e98062 100644 --- a/cmd/initialize/main_test.go +++ b/cmd/initialize/main_test.go @@ -11,28 +11,6 @@ import ( "github.com/calypr/git-drs/internal/testutils" ) -func TestInstallPrePushHook(t *testing.T) { - testutils.SetupTestGitRepo(t) - logger := drslog.NewNoOpLogger() - - if err := installPrePushHook(logger); err != nil { - t.Fatalf("installPrePushHook error: %v", err) - } - - hookPath := filepath.Join(".git", "hooks", "pre-push") - content, err := os.ReadFile(hookPath) - if err != nil { - t.Fatalf("read hook: %v", err) - } - if !strings.Contains(string(content), "git drs pre-push") { - t.Fatalf("expected hook to contain git drs pre-push") - } - - if err := installPrePushHook(logger); err != nil { - t.Fatalf("installPrePushHook second call error: %v", err) - } -} - func TestInstallPreCommitHook(t *testing.T) { testutils.SetupTestGitRepo(t) logger := drslog.NewNoOpLogger() @@ -105,4 +83,66 @@ func TestInitConfigValues(t *testing.T) { check("lfs.concurrenttransfers", "8") check("lfs.allowincompletepush", "false") + check("push.autoSetupRemote", "true") + check("filter.drs.clean", "git-drs clean -- %f") + check("filter.drs.smudge", "git-drs smudge -- %f") + check("filter.drs.process", "git-drs filter") + check("filter.drs.required", "true") +} + +func TestEnsureInitialized(t *testing.T) { + testutils.SetupTestGitRepo(t) + logger := drslog.NewNoOpLogger() + + if err := EnsureInitialized(logger); err != nil { + t.Fatalf("EnsureInitialized error: %v", err) + } + if err := EnsureInitialized(logger); err != nil { + t.Fatalf("EnsureInitialized second call error: %v", err) + } + + if _, err := os.Stat(gitrepo.DRSDir); err != nil { + t.Fatalf("expected %s to exist: %v", gitrepo.DRSDir, err) + } + filterProcess, err := gitrepo.GetGitConfigString("filter.drs.process") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.process): %v", err) + } + if filterProcess != "git-drs filter" { + t.Fatalf("unexpected filter.drs.process: %q", filterProcess) + } + filterClean, err := gitrepo.GetGitConfigString("filter.drs.clean") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.clean): %v", err) + } + if filterClean != "git-drs clean -- %f" { + t.Fatalf("unexpected filter.drs.clean: %q", filterClean) + } +} + +func TestEnsureInitializedRemovesLegacyPrePushHook(t *testing.T) { + testutils.SetupTestGitRepo(t) + logger := drslog.NewNoOpLogger() + + hookPath := filepath.Join(".git", "hooks", "pre-push") + legacyHook := "#!/bin/sh\nexec git drs pre-push-prepare\n" + if err := os.WriteFile(hookPath, []byte(legacyHook), 0o755); err != nil { + t.Fatalf("write legacy pre-push hook: %v", err) + } + + if err := EnsureInitialized(logger); err != nil { + t.Fatalf("EnsureInitialized error: %v", err) + } + + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + t.Fatalf("expected legacy pre-push hook to be removed, got err=%v", err) + } + + matches, err := filepath.Glob(hookPath + ".*") + if err != nil { + t.Fatalf("glob hook backups: %v", err) + } + if len(matches) == 0 { + t.Fatal("expected legacy pre-push hook backup to be created") + } } diff --git a/cmd/list/main.go b/cmd/list/main.go deleted file mode 100644 index dfcf7a7c..00000000 --- a/cmd/list/main.go +++ /dev/null @@ -1,59 +0,0 @@ -package list - -import ( - "context" - "fmt" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/spf13/cobra" -) - -var remote string -var pretty = false - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "list", - Short: "List DRS objects in a DRS server", - Long: "List DRS objects in a DRS server", - RunE: func(cmd *cobra.Command, args []string) error { - - logger := drslog.GetLogger() - - config, err := config.LoadConfig() - if err != nil { - return err - } - - remoteName, err := config.GetRemoteOrDefault(remote) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - - client, err := config.GetRemoteClient(remoteName, logger) - if err != nil { - return err - } - - objs, err := client.Client.DRS().ListObjects(context.Background(), 1000, 1) - if err != nil { - return err - } - - for _, drsObj := range objs.DrsObjects { - if err := common.PrintDRSObject(drsObj, pretty); err != nil { - return err - } - } - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") - Cmd.Flags().BoolVarP(&pretty, "pretty", "p", false, "pretty print JSON output") -} diff --git a/cmd/lsfiles/main.go b/cmd/lsfiles/main.go index 96c3dfd1..e796e599 100644 --- a/cmd/lsfiles/main.go +++ b/cmd/lsfiles/main.go @@ -1,117 +1,41 @@ package lsfiles import ( + "context" "fmt" - "log/slog" - "sort" - "strings" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" - "github.com/calypr/git-drs/internal/lfs" "github.com/spf13/cobra" ) var gitRemote string var drsRemote string +var includePatterns []string +var showLong bool +var nameOnly bool +var jsonOutput bool +var drsStatus bool -var ( - loadConfig = config.LoadConfig - resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { return cfg.GetRemoteOrDefault(name) } - newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*config.GitContext, error) { - return cfg.GetRemoteClient(remote, logger) +func validateOutputFlags() error { + if nameOnly && jsonOutput { + return fmt.Errorf("--name-only and --json are mutually exclusive") } - loadLFSInventory = lfs.GetAllLfsFiles - lookupScopedObjects = drsremote.ObjectsByHashForScope -) - -type fileRow struct { - OID string - Status string - Path string - Detail string -} - -func collectRows(cmd *cobra.Command, gitRemoteName, drsRemoteName string) ([]fileRow, error) { - logger := drslog.GetLogger() - - cfg, err := loadConfig() - if err != nil { - return nil, err - } - - remoteName, err := resolveRemote(cfg, drsRemoteName) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return nil, err - } - - client, err := newRemoteClient(cfg, remoteName, logger) - if err != nil { - return nil, err - } - - lfsFiles, err := loadLFSInventory(gitRemoteName, drsRemoteName, []string{}, logger) - if err != nil { - return nil, err - } - - keys := make([]string, 0, len(lfsFiles)) - for path := range lfsFiles { - keys = append(keys, path) - } - sort.Strings(keys) - - rows := make([]fileRow, 0, len(keys)) - for _, path := range keys { - info := lfsFiles[path] - row := fileRow{ - OID: info.Oid, - Path: path, - } - - results, err := lookupScopedObjects(cmd.Context(), client, info.Oid) - switch { - case err != nil: - row.Status = "error" - row.Detail = err.Error() - case len(results) == 0: - row.Status = "missing" - row.Detail = "-" - default: - row.Status = "present" - ids := make([]string, 0, len(results)) - for _, res := range results { - ids = append(ids, "drs://"+res.Id) - } - row.Detail = strings.Join(ids, ",") - } - - rows = append(rows, row) - } - - return rows, nil -} - -func printRows(cmd *cobra.Command, rows []fileRow) error { - if _, err := fmt.Fprintf(cmd.OutOrStdout(), "OID\tSTATUS\tPATH\tDETAIL\n"); err != nil { - return err - } - for _, row := range rows { - if _, err := fmt.Fprintf(cmd.OutOrStdout(), "%s\t%s\t%s\t%s\n", row.OID, row.Status, row.Path, row.Detail); err != nil { - return err - } + if showLong && nameOnly { + return fmt.Errorf("--long and --name-only are mutually exclusive") } return nil } -// Cmd line declaration var Cmd = &cobra.Command{ - Use: "ls-files", - Short: "List local LFS-tracked files and their DRS registration status", + Use: "ls-files [pathspec...]", + Short: "List tracked DRS/LFS pointer files in the repository", + Long: "List tracked DRS/Git-LFS pointer files in the repository. By default this behaves like a local file inventory. Use --drs to also resolve DRS registration status.", RunE: func(cmd *cobra.Command, args []string) error { - rows, err := collectRows(cmd, gitRemote, drsRemote) + if err := validateOutputFlags(); err != nil { + return err + } + patterns := append([]string{}, includePatterns...) + patterns = append(patterns, args...) + rows, err := collectRows(context.Background(), gitRemote, drsRemote, patterns, drsStatus) if err != nil { return err } @@ -122,4 +46,9 @@ var Cmd = &cobra.Command{ func init() { Cmd.Flags().StringVarP(&gitRemote, "git-remote", "r", "", "target remote Git server (default: origin)") Cmd.Flags().StringVarP(&drsRemote, "drs-remote", "d", "", "target remote DRS server (default: origin)") + Cmd.Flags().StringArrayVarP(&includePatterns, "include", "I", nil, "include pathspec/glob pattern(s)") + Cmd.Flags().BoolVarP(&showLong, "long", "l", false, "show full object IDs") + Cmd.Flags().BoolVarP(&nameOnly, "name-only", "n", false, "show only file paths") + Cmd.Flags().BoolVar(&jsonOutput, "json", false, "emit JSON output") + Cmd.Flags().BoolVar(&drsStatus, "drs", false, "include DRS registration lookup details") } diff --git a/cmd/lsfiles/main_test.go b/cmd/lsfiles/main_test.go index 492b0b4f..12d66154 100644 --- a/cmd/lsfiles/main_test.go +++ b/cmd/lsfiles/main_test.go @@ -1,94 +1,30 @@ package lsfiles -import ( - "bytes" - "context" - "errors" - "log/slog" - "strings" - "testing" +import "testing" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/lfs" - drsapi "github.com/calypr/syfon/apigen/client/drs" - "github.com/spf13/cobra" -) +func resetFlagsForTest() { + gitRemote = "" + drsRemote = "" + includePatterns = nil + showLong = false + nameOnly = false + jsonOutput = false + drsStatus = false +} -func TestCollectRowsAndPrintRows(t *testing.T) { - oldLoadConfig := loadConfig - oldResolveRemote := resolveRemote - oldNewRemoteClient := newRemoteClient - oldLoadLFSInventory := loadLFSInventory - oldLookupScopedObjects := lookupScopedObjects - t.Cleanup(func() { - loadConfig = oldLoadConfig - resolveRemote = oldResolveRemote - newRemoteClient = oldNewRemoteClient - loadLFSInventory = oldLoadLFSInventory - lookupScopedObjects = oldLookupScopedObjects - }) +func TestValidateOutputFlags(t *testing.T) { + resetFlagsForTest() - loadConfig = func() (*config.Config, error) { - return &config.Config{}, nil - } - resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { - return config.Remote("origin"), nil - } - newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*config.GitContext, error) { - return &config.GitContext{}, nil - } - loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { - return map[string]lfs.LfsFileInfo{ - "b/file2.bin": {Name: "b/file2.bin", Oid: strings.Repeat("b", 64)}, - "a/file1.bin": {Name: "a/file1.bin", Oid: strings.Repeat("a", 64)}, - "c/file3.bin": {Name: "c/file3.bin", Oid: strings.Repeat("c", 64)}, - }, nil - } - lookupScopedObjects = func(ctx context.Context, drsCtx *config.GitContext, checksum string) ([]drsapi.DrsObject, error) { - switch checksum { - case strings.Repeat("a", 64): - return []drsapi.DrsObject{{Id: "did-1"}}, nil - case strings.Repeat("b", 64): - return nil, nil - default: - return nil, errors.New("lookup failed") - } + nameOnly = true + jsonOutput = true + if err := validateOutputFlags(); err == nil { + t.Fatal("expected name-only/json conflict") } - cmd := &cobra.Command{} - rows, err := collectRows(cmd, "", "") - if err != nil { - t.Fatalf("collectRows returned error: %v", err) - } - if len(rows) != 3 { - t.Fatalf("expected 3 rows, got %d", len(rows)) - } - if rows[0].Path != "a/file1.bin" || rows[0].Status != "present" || rows[0].Detail != "drs://did-1" { - t.Fatalf("unexpected first row: %+v", rows[0]) - } - if rows[1].Path != "b/file2.bin" || rows[1].Status != "missing" || rows[1].Detail != "-" { - t.Fatalf("unexpected second row: %+v", rows[1]) - } - if rows[2].Path != "c/file3.bin" || rows[2].Status != "error" || rows[2].Detail != "lookup failed" { - t.Fatalf("unexpected third row: %+v", rows[2]) - } - - var out bytes.Buffer - cmd.SetOut(&out) - if err := printRows(cmd, rows); err != nil { - t.Fatalf("printRows returned error: %v", err) - } - got := out.String() - if !strings.Contains(got, "OID\tSTATUS\tPATH\tDETAIL\n") { - t.Fatalf("missing header in output: %q", got) - } - if !strings.Contains(got, rows[0].OID+"\tpresent\ta/file1.bin\tdrs://did-1\n") { - t.Fatalf("missing present row: %q", got) - } - if !strings.Contains(got, rows[1].OID+"\tmissing\tb/file2.bin\t-\n") { - t.Fatalf("missing missing row: %q", got) - } - if !strings.Contains(got, rows[2].OID+"\terror\tc/file3.bin\tlookup failed\n") { - t.Fatalf("missing error row: %q", got) + resetFlagsForTest() + nameOnly = true + showLong = true + if err := validateOutputFlags(); err == nil { + t.Fatal("expected long/name-only conflict") } } diff --git a/cmd/lsfiles/output.go b/cmd/lsfiles/output.go new file mode 100644 index 00000000..f855b9f5 --- /dev/null +++ b/cmd/lsfiles/output.go @@ -0,0 +1,45 @@ +package lsfiles + +import ( + "encoding/json" + "fmt" + + "github.com/spf13/cobra" +) + +func printRows(cmd *cobra.Command, rows []fileRow) error { + if jsonOutput { + enc := json.NewEncoder(cmd.OutOrStdout()) + enc.SetIndent("", " ") + return enc.Encode(rows) + } + for _, row := range rows { + switch { + case nameOnly: + if _, err := fmt.Fprintln(cmd.OutOrStdout(), row.Path); err != nil { + return err + } + case drsStatus: + oid := row.ShortOID + if showLong { + oid = row.OID + } + detail := row.Detail + if detail == "" { + detail = "-" + } + if _, err := fmt.Fprintf(cmd.OutOrStdout(), "%s %s %s\t%s\n", oid, row.Status, row.Path, detail); err != nil { + return err + } + default: + oid := row.ShortOID + if showLong { + oid = row.OID + } + if _, err := fmt.Fprintf(cmd.OutOrStdout(), "%s %s %s\n", oid, row.Status, row.Path); err != nil { + return err + } + } + } + return nil +} diff --git a/cmd/lsfiles/service.go b/cmd/lsfiles/service.go new file mode 100644 index 00000000..91b5d172 --- /dev/null +++ b/cmd/lsfiles/service.go @@ -0,0 +1,309 @@ +package lsfiles + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/lookup" + "github.com/calypr/git-drs/internal/remoteruntime" + drsapi "github.com/calypr/syfon/apigen/client/drs" +) + +type fileRow struct { + OID string `json:"oid"` + ShortOID string `json:"short_oid"` + Status string `json:"status"` + Path string `json:"path"` + Localized bool `json:"localized"` + Registered bool `json:"registered,omitempty"` + DRSIDs []string `json:"drs_ids,omitempty"` + Detail string `json:"detail,omitempty"` +} + +var ( + loadConfig = config.LoadConfig + resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { return cfg.GetRemoteOrDefault(name) } + newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*remoteruntime.GitContext, error) { + return remoteruntime.New(cfg, remote, logger) + } + loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + if len(branches) == 0 { + return lfs.GetTrackedLfsFiles(logger) + } + return lfs.GetLfsFilesForRefs(branches, logger) + } + listRemoteRefs = defaultListRemoteRefs + listGitRemotes = defaultListGitRemotes + resolveDefaultRemote = defaultResolveDefaultRemote + lookupScopedObjectsBatch = lookup.ObjectsByHashesForScope +) + +func defaultListRemoteRefs(gitRemoteName string) ([]string, error) { + if strings.TrimSpace(gitRemoteName) == "" { + return nil, nil + } + + cmd := exec.Command("git", "for-each-ref", "--format=%(refname)", "refs/remotes/"+gitRemoteName) + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("list refs for remote %s: %w", gitRemoteName, err) + } + + lines := strings.Split(string(out), "\n") + refs := make([]string, 0, len(lines)) + for _, line := range lines { + ref := strings.TrimSpace(line) + if ref == "" || strings.HasSuffix(ref, "/HEAD") { + continue + } + refs = append(refs, ref) + } + sort.Strings(refs) + return refs, nil +} + +func defaultListGitRemotes() ([]string, error) { + cmd := exec.Command("git", "remote") + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("list git remotes: %w", err) + } + + lines := strings.Split(string(out), "\n") + remotes := make([]string, 0, len(lines)) + for _, line := range lines { + name := strings.TrimSpace(line) + if name == "" { + continue + } + remotes = append(remotes, name) + } + sort.Strings(remotes) + return remotes, nil +} + +func defaultResolveDefaultRemote() string { + cfg, err := loadConfig() + if err == nil && cfg != nil { + if remote, err := cfg.GetRemoteOrDefault(""); err == nil { + return strings.TrimSpace(string(remote)) + } + } + + remotes, err := listGitRemotes() + if err != nil || len(remotes) == 0 { + return "" + } + for _, remote := range remotes { + if remote == config.ORIGIN { + return remote + } + } + if len(remotes) == 1 { + return remotes[0] + } + return "" +} + +func collectRows(ctx context.Context, gitRemoteName, drsRemoteName string, patterns []string, resolveDRS bool) ([]fileRow, error) { + logger := drslog.GetLogger() + + var client *remoteruntime.GitContext + if resolveDRS { + cfg, err := loadConfig() + if err != nil { + return nil, err + } + + remoteName, err := resolveRemote(cfg, drsRemoteName) + if err != nil { + logger.Error(fmt.Sprintf("Error getting remote: %v", err)) + return nil, err + } + + client, err = newRemoteClient(cfg, remoteName, logger) + if err != nil { + return nil, err + } + } + + var ( + refs []string + err error + ) + if strings.TrimSpace(gitRemoteName) != "" { + refs, err = listRemoteRefs(gitRemoteName) + if err != nil { + return nil, err + } + } + + lfsFiles, err := loadLFSInventory(gitRemoteName, drsRemoteName, refs, logger) + if err != nil { + return nil, err + } + if len(lfsFiles) == 0 && strings.TrimSpace(gitRemoteName) == "" { + fallbackRemote := resolveDefaultRemote() + if fallbackRemote != "" { + refs, err = listRemoteRefs(fallbackRemote) + if err != nil { + return nil, err + } + if len(refs) > 0 { + lfsFiles, err = loadLFSInventory(fallbackRemote, drsRemoteName, refs, logger) + if err != nil { + return nil, err + } + } + } + } + + keys := make([]string, 0, len(lfsFiles)) + for path := range lfsFiles { + keys = append(keys, path) + } + sort.Strings(keys) + + rows := make([]fileRow, 0, len(keys)) + var drsResults map[string][]drsapi.DrsObject + var drsLookupErr error + if resolveDRS { + oids := make([]string, 0, len(keys)) + seenOIDs := make(map[string]struct{}, len(keys)) + for _, path := range keys { + if !matchesAnyPattern(path, patterns) { + continue + } + oid := lfsFiles[path].Oid + if oid == "" { + continue + } + if _, exists := seenOIDs[oid]; exists { + continue + } + seenOIDs[oid] = struct{}{} + oids = append(oids, oid) + } + drsResults, drsLookupErr = lookupScopedObjectsBatch(ctx, client, oids) + } + for _, path := range keys { + if !matchesAnyPattern(path, patterns) { + continue + } + info := lfsFiles[path] + row := fileRow{ + OID: info.Oid, + ShortOID: shortOID(info.Oid), + Path: path, + Localized: isLocalized(path), + } + row.Status = "-" + if row.Localized { + row.Status = "*" + } + + if resolveDRS { + switch { + case drsLookupErr != nil: + row.Detail = drsLookupErr.Error() + default: + results := drsResults[info.Oid] + if len(results) == 0 { + row.Registered = false + break + } + row.Registered = true + row.DRSIDs = make([]string, 0, len(results)) + for _, res := range results { + row.DRSIDs = append(row.DRSIDs, "drs://"+res.Id) + } + row.Detail = strings.Join(row.DRSIDs, ",") + } + } + + rows = append(rows, row) + } + + return rows, nil +} + +func shortOID(oid string) string { + if len(oid) <= 10 { + return oid + } + return oid[:10] +} + +func matchesAnyPattern(path string, patterns []string) bool { + if len(patterns) == 0 { + return true + } + normalized := filepath.ToSlash(filepath.Clean(path)) + for _, pattern := range patterns { + pattern = strings.TrimSpace(pattern) + if pattern == "" { + continue + } + if matchesPattern(normalized, pattern) { + return true + } + } + return false +} + +func matchesPattern(path, pattern string) bool { + pattern = filepath.ToSlash(filepath.Clean(pattern)) + if !strings.ContainsAny(pattern, "*?[") { + return path == pattern + } + re, err := regexp.Compile(globToRegexp(pattern)) + if err != nil { + return false + } + return re.MatchString(path) +} + +func globToRegexp(pattern string) string { + var b strings.Builder + b.WriteString("^") + for i := 0; i < len(pattern); i++ { + ch := pattern[i] + switch ch { + case '*': + if i+1 < len(pattern) && pattern[i+1] == '*' { + b.WriteString(".*") + i++ + continue + } + b.WriteString(`[^/]*`) + case '?': + b.WriteString(`[^/]`) + case '.', '+', '(', ')', '|', '^', '$', '{', '}', '[', ']', '\\': + b.WriteByte('\\') + b.WriteByte(ch) + default: + b.WriteByte(ch) + } + } + b.WriteString("$") + return b.String() +} + +func isLocalized(path string) bool { + payload, err := os.ReadFile(path) + if err != nil { + return false + } + _, _, ok := lfs.ParseLFSPointer(payload) + return !ok +} diff --git a/cmd/ping/main.go b/cmd/ping/main.go new file mode 100644 index 00000000..e4136aae --- /dev/null +++ b/cmd/ping/main.go @@ -0,0 +1,248 @@ +package ping + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/remoteruntime" + bucketapi "github.com/calypr/syfon/apigen/client/bucketapi" + syfoncommon "github.com/calypr/syfon/common" + "github.com/spf13/cobra" +) + +type statusInfo struct { + Remote config.Remote + IsDefault bool + RemoteType string + Endpoint string + Organization string + Project string + Bucket string + StoragePrefix string + AuthMode string +} + +var pingHealth = func(ctx context.Context, gc *remoteruntime.GitContext) error { + return gc.Client.Health().Ping(ctx) +} + +var pingScopeAccess = func(ctx context.Context, gc *remoteruntime.GitContext) (scopeAccessInfo, error) { + return checkScopeAccess(ctx, gc) +} + +type scopeAccessInfo struct { + Checked bool + VisibleBucket string + ProjectReadable bool +} + +var Cmd = &cobra.Command{ + Use: "ping [remote-name]", + Short: "Show effective remote setup and verify the remote responds", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) > 1 { + cmd.SilenceUsage = false + return fmt.Errorf("error: accepts at most 1 argument (remote name), received %d\n\nUsage: %s\n\nSee 'git drs ping --help' for more details", len(args), cmd.UseLine()) + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + logger := drslog.GetLogger() + status, gc, err := resolveStatus(args, logger) + if err != nil { + return err + } + printStatus(status) + + if err := pingHealth(cmd.Context(), gc); err != nil { + return fmt.Errorf("remote health check failed for %q (%s): %w", status.Remote, status.Endpoint, err) + } + fmt.Println("health: ok") + + scopeInfo, err := pingScopeAccess(cmd.Context(), gc) + if err != nil { + return fmt.Errorf("configured scope access check failed for remote %q (organization=%s project=%s bucket=%s): %w", + status.Remote, + blankIfEmpty(status.Organization), + blankIfEmpty(status.Project), + blankIfEmpty(status.Bucket), + err, + ) + } + if scopeInfo.Checked { + fmt.Println("scope_access: ok") + if strings.TrimSpace(scopeInfo.VisibleBucket) != "" { + fmt.Printf("visible_bucket: %s\n", scopeInfo.VisibleBucket) + } + if scopeInfo.ProjectReadable { + fmt.Println("project_access: readable") + } + } + return nil + }, +} + +func resolveStatus(args []string, logger *slog.Logger) (statusInfo, *remoteruntime.GitContext, error) { + cfg, err := config.LoadConfig() + if err != nil { + return statusInfo{}, nil, err + } + + var remoteArg string + if len(args) == 1 { + remoteArg = args[0] + } + remoteName, err := cfg.GetRemoteOrDefault(remoteArg) + if err != nil { + return statusInfo{}, nil, err + } + + remoteCfg := cfg.GetRemote(remoteName) + if remoteCfg == nil { + return statusInfo{}, nil, fmt.Errorf("no remote configuration found for %q", remoteName) + } + + gc, err := remoteruntime.New(cfg, remoteName, logger) + if err != nil { + return statusInfo{}, nil, err + } + + status := statusInfo{ + Remote: remoteName, + IsDefault: remoteName == cfg.DefaultRemote, + Endpoint: remoteCfg.GetEndpoint(), + Organization: remoteCfg.GetOrganization(), + Project: remoteCfg.GetProjectId(), + Bucket: gc.BucketName, + StoragePrefix: gc.StoragePrefix, + AuthMode: authMode(gc), + } + switch remoteCfg.(type) { + case *config.Gen3Remote: + status.RemoteType = string(config.Gen3ServerType) + case *config.LocalRemote: + status.RemoteType = string(config.LocalServerType) + default: + status.RemoteType = "unknown" + } + + return status, gc, nil +} + +func printStatus(status statusInfo) { + def := "" + if status.IsDefault { + def = " (default)" + } + fmt.Printf("remote: %s%s\n", status.Remote, def) + fmt.Printf("type: %s\n", status.RemoteType) + fmt.Printf("endpoint: %s\n", status.Endpoint) + fmt.Printf("organization: %s\n", blankIfEmpty(status.Organization)) + fmt.Printf("project: %s\n", blankIfEmpty(status.Project)) + fmt.Printf("bucket: %s\n", blankIfEmpty(status.Bucket)) + fmt.Printf("storage_prefix: %s\n", blankIfEmpty(status.StoragePrefix)) + fmt.Printf("auth: %s\n", status.AuthMode) +} + +func authMode(gc *remoteruntime.GitContext) string { + if gc == nil || gc.Credential == nil { + return "none" + } + if strings.TrimSpace(gc.Credential.AccessToken) != "" { + return "bearer" + } + if strings.TrimSpace(gc.Credential.KeyID) != "" || strings.TrimSpace(gc.Credential.APIKey) != "" { + return "basic" + } + return "none" +} + +func blankIfEmpty(v string) string { + v = strings.TrimSpace(v) + if v == "" { + return "-" + } + return v +} + +func checkScopeAccess(ctx context.Context, gc *remoteruntime.GitContext) (scopeAccessInfo, error) { + if gc == nil || gc.Client == nil { + return scopeAccessInfo{}, fmt.Errorf("DRS client unavailable") + } + + info := scopeAccessInfo{} + organization := strings.TrimSpace(gc.Organization) + project := strings.TrimSpace(gc.ProjectId) + bucket := strings.TrimSpace(gc.BucketName) + + if organization == "" && project == "" && bucket == "" { + return info, nil + } + info.Checked = true + + if organization != "" && project != "" { + visibleBucket, err := visibleBucketForScope(ctx, gc, organization, project) + if err != nil { + return scopeAccessInfo{}, err + } + info.VisibleBucket = visibleBucket + if bucket != "" && !strings.EqualFold(strings.TrimSpace(visibleBucket), bucket) { + return scopeAccessInfo{}, fmt.Errorf("server exposes bucket %q for configured scope, but repo is configured for bucket %q", visibleBucket, bucket) + } + } + + if project != "" { + if _, err := gc.Client.DRS().GetProjectSample(ctx, project, 1); err != nil { + return scopeAccessInfo{}, fmt.Errorf("project listing failed: %w", err) + } + info.ProjectReadable = true + } + + return info, nil +} + +func visibleBucketForScope(ctx context.Context, gc *remoteruntime.GitContext, organization, project string) (string, error) { + payload, err := gc.Client.Buckets().List(ctx) + if err != nil { + return "", fmt.Errorf("bucket visibility lookup failed: %w", err) + } + + resource, err := syfoncommon.ResourcePath(organization, project) + if err != nil { + return "", fmt.Errorf("build scope resource path: %w", err) + } + + matches := findBucketsByResource(payload, resource) + if len(matches) == 0 { + return "", fmt.Errorf("no visible server bucket matched configured scope %s", resource) + } + if len(matches) > 1 { + return "", fmt.Errorf("multiple visible server buckets matched configured scope %s: %s", resource, strings.Join(matches, ", ")) + } + return matches[0], nil +} + +func findBucketsByResource(payload bucketapi.BucketsResponse, resource string) []string { + resource = syfoncommon.NormalizeAccessResource(resource) + if resource == "" { + return nil + } + + matches := make([]string, 0) + for bucket, meta := range payload.S3BUCKETS { + if meta.Programs == nil { + continue + } + for _, candidate := range *meta.Programs { + if syfoncommon.NormalizeAccessResource(candidate) == resource { + matches = append(matches, bucket) + break + } + } + } + return matches +} diff --git a/cmd/ping/main_test.go b/cmd/ping/main_test.go new file mode 100644 index 00000000..eb1c8141 --- /dev/null +++ b/cmd/ping/main_test.go @@ -0,0 +1,197 @@ +package ping + +import ( + "bytes" + "context" + "errors" + "io" + "os" + "strings" + "testing" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/gitrepo" + "github.com/calypr/git-drs/internal/remoteruntime" + "github.com/calypr/git-drs/internal/testutils" +) + +func TestPingCmdArgs(t *testing.T) { + if err := Cmd.Args(Cmd, nil); err != nil { + t.Fatalf("unexpected error with no args: %v", err) + } + if err := Cmd.Args(Cmd, []string{"origin"}); err != nil { + t.Fatalf("unexpected error with one arg: %v", err) + } + if err := Cmd.Args(Cmd, []string{"origin", "extra"}); err == nil { + t.Fatal("expected error for extra args") + } +} + +func TestResolveStatusLocalRemote(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateTestConfig(t, tmpDir, &config.Config{ + DefaultRemote: config.Remote(config.ORIGIN), + Remotes: map[config.Remote]config.RemoteSelect{ + config.Remote(config.ORIGIN): { + Local: &config.LocalRemote{ + BaseURL: "http://127.0.0.1:8080", + ProjectID: "end_to_end_test", + Bucket: "cbds", + Organization: "calypr", + BasicUsername: "drs-user", + BasicPassword: "drs-pass", + }, + }, + }, + }) + if err := gitrepo.SetBucketMapping("calypr", "end_to_end_test", "cbds", "prefix"); err != nil { + t.Fatalf("SetBucketMapping failed: %v", err) + } + + status, _, err := resolveStatus(nil, drslog.NewNoOpLogger()) + if err != nil { + t.Fatalf("resolveStatus returned error: %v", err) + } + if status.Remote != "origin" || !status.IsDefault { + t.Fatalf("unexpected remote selection: %+v", status) + } + if status.RemoteType != "local" || status.Endpoint != "http://127.0.0.1:8080" { + t.Fatalf("unexpected remote type/endpoint: %+v", status) + } + if status.Organization != "calypr" || status.Project != "end_to_end_test" { + t.Fatalf("unexpected scope: %+v", status) + } + if status.Bucket != "cbds" || status.StoragePrefix != "prefix" { + t.Fatalf("unexpected bucket scope: %+v", status) + } + if status.AuthMode != "none" { + t.Fatalf("expected auth mode none from client credential shape, got %+v", status) + } +} + +func TestPingRunEPrintsStatusAndHealth(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateTestConfig(t, tmpDir, &config.Config{ + DefaultRemote: config.Remote(config.ORIGIN), + Remotes: map[config.Remote]config.RemoteSelect{ + config.Remote(config.ORIGIN): { + Local: &config.LocalRemote{ + BaseURL: "http://127.0.0.1:8080", + ProjectID: "end_to_end_test", + Bucket: "cbds", + Organization: "calypr", + }, + }, + }, + }) + if err := gitrepo.SetBucketMapping("calypr", "end_to_end_test", "cbds", "prefix"); err != nil { + t.Fatalf("SetBucketMapping failed: %v", err) + } + + oldHealth := pingHealth + pingHealth = func(ctx context.Context, gc *remoteruntime.GitContext) error { + if gc == nil || gc.ProjectId != "end_to_end_test" { + t.Fatalf("unexpected git context: %+v", gc) + } + return nil + } + t.Cleanup(func() { pingHealth = oldHealth }) + + oldScopeAccess := pingScopeAccess + pingScopeAccess = func(ctx context.Context, gc *remoteruntime.GitContext) (scopeAccessInfo, error) { + if gc == nil || gc.ProjectId != "end_to_end_test" { + t.Fatalf("unexpected git context for scope probe: %+v", gc) + } + return scopeAccessInfo{ + Checked: true, + VisibleBucket: "cbds", + ProjectReadable: true, + }, nil + } + t.Cleanup(func() { pingScopeAccess = oldScopeAccess }) + + oldStdout := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe: %v", err) + } + os.Stdout = w + t.Cleanup(func() { os.Stdout = oldStdout }) + + runErr := Cmd.RunE(Cmd, nil) + _ = w.Close() + if runErr != nil { + t.Fatalf("Cmd.RunE returned error: %v", runErr) + } + + var buf bytes.Buffer + if _, err := io.Copy(&buf, r); err != nil { + t.Fatalf("read stdout: %v", err) + } + got := buf.String() + for _, want := range []string{ + "remote: origin (default)", + "type: local", + "endpoint: http://127.0.0.1:8080", + "organization: calypr", + "project: end_to_end_test", + "bucket: cbds", + "storage_prefix: prefix", + "health: ok", + "scope_access: ok", + "visible_bucket: cbds", + "project_access: readable", + } { + if !strings.Contains(got, want) { + t.Fatalf("expected output to contain %q, got %q", want, got) + } + } +} + +func TestPingRunEReturnsReadableScopeError(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateTestConfig(t, tmpDir, &config.Config{ + DefaultRemote: config.Remote(config.ORIGIN), + Remotes: map[config.Remote]config.RemoteSelect{ + config.Remote(config.ORIGIN): { + Local: &config.LocalRemote{ + BaseURL: "http://127.0.0.1:8080", + ProjectID: "end_to_end_test", + Bucket: "cbds", + Organization: "calypr", + }, + }, + }, + }) + if err := gitrepo.SetBucketMapping("calypr", "end_to_end_test", "cbds", "prefix"); err != nil { + t.Fatalf("SetBucketMapping failed: %v", err) + } + + oldHealth := pingHealth + pingHealth = func(ctx context.Context, gc *remoteruntime.GitContext) error { return nil } + t.Cleanup(func() { pingHealth = oldHealth }) + + oldScopeAccess := pingScopeAccess + pingScopeAccess = func(ctx context.Context, gc *remoteruntime.GitContext) (scopeAccessInfo, error) { + return scopeAccessInfo{}, errors.New("bucket visibility lookup failed: unexpected response: 403: denied") + } + t.Cleanup(func() { pingScopeAccess = oldScopeAccess }) + + err := Cmd.RunE(Cmd, nil) + if err == nil { + t.Fatal("expected scope access error") + } + got := err.Error() + for _, want := range []string{ + "configured scope access check failed", + "organization=calypr", + "project=end_to_end_test", + "bucket=cbds", + "bucket visibility lookup failed: unexpected response: 403: denied", + } { + if !strings.Contains(got, want) { + t.Fatalf("expected error to contain %q, got %q", want, got) + } + } +} diff --git a/cmd/precommit/cache_apply.go b/cmd/precommit/cache_apply.go new file mode 100644 index 00000000..cb72d610 --- /dev/null +++ b/cmd/precommit/cache_apply.go @@ -0,0 +1,102 @@ +package precommit + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/calypr/git-drs/internal/precommit_cache" +) + +func handleUpsert(ctx context.Context, cache *precommit_cache.Cache, path, now string) error { + oid, isLFS, err := stagedLFSOID(ctx, path) + if err != nil { + return nil + } + if !isLFS { + return nil + } + + prev, prevExists, err := precommit_cache.ReadPathEntry(cache, path) + if err != nil { + return err + } + + if err := precommit_cache.WritePathEntry(cache, precommit_cache.PathEntry{ + Path: path, + LFSOID: oid, + UpdatedAt: now, + }); err != nil { + return err + } + + contentChanged := prevExists && prev != nil && prev.LFSOID != oid + if err := precommit_cache.UpsertOIDPath(cache, oid, "", path, "", now, contentChanged); err != nil { + return err + } + if contentChanged { + if err := precommit_cache.RemoveOIDPath(cache, prev.LFSOID, path, now); err != nil { + return fmt.Errorf("remove stale OID path mapping for %s: %w", path, err) + } + } + + return nil +} + +func handleDelete(_ context.Context, cache *precommit_cache.Cache, tombsDir, path, now string) error { + entry, ok, err := precommit_cache.ReadPathEntry(cache, path) + if err != nil || !ok { + return nil + } + if err := removeIfExists(precommit_cache.PathEntryPath(cache, path)); err != nil { + return fmt.Errorf("remove path entry for %s: %w", path, err) + } + if entry.LFSOID != "" { + if err := precommit_cache.RemoveOIDPath(cache, entry.LFSOID, path, now); err != nil { + return fmt.Errorf("remove OID path mapping for %s: %w", path, err) + } + } + + tombFile := filepath.Join(tombsDir, precommit_cache.EncodePath(path)+".json") + if err := writeJSONAtomic(tombFile, map[string]string{ + "path": path, + "deleted_at": now, + }); err != nil { + return fmt.Errorf("write tombstone for %s: %w", path, err) + } + + return nil +} + +func writeJSONAtomic(path string, v any) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + tmp := path + ".tmp" + f, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return err + } + + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + if err := enc.Encode(v); err != nil { + _ = f.Close() + _ = removeIfExists(tmp) + return err + } + if err := f.Sync(); err != nil { + _ = f.Close() + _ = removeIfExists(tmp) + return err + } + if err := f.Close(); err != nil { + _ = removeIfExists(tmp) + return err + } + return os.Rename(tmp, path) +} diff --git a/cmd/precommit/changes.go b/cmd/precommit/changes.go new file mode 100644 index 00000000..e3d804a9 --- /dev/null +++ b/cmd/precommit/changes.go @@ -0,0 +1,91 @@ +package precommit + +import ( + "bufio" + "bytes" + "context" + "fmt" + "strconv" + "strings" +) + +// stagedChanges parses: git diff --cached --name-status -M +func stagedChanges(ctx context.Context) ([]Change, error) { + out, err := git(ctx, "diff", "--cached", "--name-status", "-M") + if err != nil { + return nil, err + } + var changes []Change + sc := bufio.NewScanner(bytes.NewReader(out)) + for sc.Scan() { + line := sc.Text() + if strings.TrimSpace(line) == "" { + continue + } + parts := strings.Split(line, "\t") + if len(parts) < 2 { + continue + } + status := parts[0] + switch { + case status == "A": + changes = append(changes, Change{Kind: KindAdd, NewPath: parts[1], Status: status}) + case status == "M": + changes = append(changes, Change{Kind: KindModify, NewPath: parts[1], Status: status}) + case status == "D": + changes = append(changes, Change{Kind: KindDelete, NewPath: parts[1], Status: status}) + case strings.HasPrefix(status, "R") && len(parts) >= 3: + changes = append(changes, Change{Kind: KindRename, OldPath: parts[1], NewPath: parts[2], Status: status}) + } + } + if err := sc.Err(); err != nil { + return nil, err + } + return changes, nil +} + +func stagedLFSOID(ctx context.Context, path string) (string, bool, error) { + out, err := git(ctx, "show", ":"+path) + if err != nil { + return "", false, err + } + + var hasSpec bool + var oid string + sc := bufio.NewScanner(bytes.NewReader(out)) + for sc.Scan() { + line := sc.Text() + if line == lfsSpecLine { + hasSpec = true + continue + } + if strings.HasPrefix(line, "oid sha256:") { + hex := strings.TrimSpace(strings.TrimPrefix(line, "oid sha256:")) + if hex != "" { + oid = "sha256:" + hex + } + } + if hasSpec && oid != "" { + break + } + } + if err := sc.Err(); err != nil { + return "", false, err + } + if hasSpec && oid != "" { + return oid, true, nil + } + return "", false, nil +} + +func stagedBlobSize(ctx context.Context, path string) (int64, error) { + out, err := git(ctx, "cat-file", "-s", ":"+path) + if err != nil { + return 0, err + } + size, err := strconv.ParseInt(strings.TrimSpace(string(out)), 10, 64) + if err != nil { + return 0, fmt.Errorf("parse staged blob size for %s: %w", path, err) + } + return size, nil +} diff --git a/cmd/precommit/git.go b/cmd/precommit/git.go new file mode 100644 index 00000000..f2d29599 --- /dev/null +++ b/cmd/precommit/git.go @@ -0,0 +1,64 @@ +package precommit + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" +) + +func git(ctx context.Context, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Env = os.Environ() + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + msg := strings.TrimSpace(stderr.String()) + if msg == "" { + msg = err.Error() + } + return nil, fmt.Errorf("git %s: %s", strings.Join(args, " "), msg) + } + return stdout.Bytes(), nil +} + +func moveFileBestEffort(src, dst string) error { + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + if err := os.Rename(src, dst); err == nil { + return nil + } + + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return err + } + if _, err := io.Copy(out, in); err != nil { + _ = out.Close() + return err + } + if err := out.Close(); err != nil { + return err + } + return os.Remove(src) +} + +func removeIfExists(path string) error { + err := os.Remove(path) + if err == nil || os.IsNotExist(err) { + return nil + } + return err +} diff --git a/cmd/precommit/main.go b/cmd/precommit/main.go index ac2f3047..77dadb70 100644 --- a/cmd/precommit/main.go +++ b/cmd/precommit/main.go @@ -1,72 +1,17 @@ -// Package precommit -// ------------------------------------- -// LFS-only local cache updater for: -// - Path -> OID : .git/drs/pre-commit/v1/paths/.json -// - OID -> Paths + S3 URL hint : .git/drs/pre-commit/v1/oids/.json -// -// This hook is intentionally: -// - LFS-only (non-LFS paths are ignored) -// - local-only (no network, no server index reads) -// - index-based (reads STAGED content via `git show :`) -// -// Note: This is a reference implementation. Adjust logging/policy as desired. package precommit -import ( - "bufio" - "bytes" - "context" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "sort" - "strings" - "time" - - "github.com/spf13/cobra" -) +import "github.com/spf13/cobra" const ( - cacheVersionDir = "drs/pre-commit/v1" - lfsSpecLine = "version https://git-lfs.github.com/spec/v1" + lfsSpecLine = "version https://git-lfs.github.com/spec/v1" + defaultDirectCommitWarningThreshold = int64(10 * 1024 * 1024) ) -type PathEntry struct { - Path string `json:"path"` - LFSOID string `json:"lfs_oid"` - UpdatedAt string `json:"updated_at"` -} - -type OIDEntry struct { - LFSOID string `json:"lfs_oid"` - Paths []string `json:"paths"` - S3URL string `json:"s3_url,omitempty"` // hint only; may be empty - UpdatedAt string `json:"updated_at"` - ContentChange bool `json:"content_changed"` -} - -type ChangeKind int - -const ( - KindAdd ChangeKind = iota - KindModify - KindDelete - KindRename +var ( + directCommitWarningThresholdBytes = defaultDirectCommitWarningThreshold + confirmOversizedDirectGitCommit = promptOversizedDirectGitCommit ) -type Change struct { - Kind ChangeKind - OldPath string // for rename - NewPath string // for rename (and for add/modify/delete uses NewPath) - Status string // raw status, e.g. "A", "M", "D", "R100" -} - // Cmd line declaration var Cmd = &cobra.Command{ Use: "precommit", @@ -74,474 +19,6 @@ var Cmd = &cobra.Command{ Long: "Pre-commit hook that updates the local DRS pre-commit cache", Args: cobra.ExactArgs(0), RunE: func(cmd *cobra.Command, args []string) error { - return run(context.Background()) + return run(cmd.Context()) }, } - -func main() { - ctx := context.Background() - if err := run(ctx); err != nil { - // For a reference impl, treat errors as non-fatal unless you want strict enforcement. - // Exiting non-zero blocks the commit. - fmt.Fprintf(os.Stderr, "pre-commit drs cache: %v\n", err) - os.Exit(1) - } -} - -func run(ctx context.Context) error { - gitDir, err := gitRevParseGitDir(ctx) - if err != nil { - return err - } - - cacheRoot := filepath.Join(gitDir, cacheVersionDir) - pathsDir := filepath.Join(cacheRoot, "paths") - oidsDir := filepath.Join(cacheRoot, "oids") - tombsDir := filepath.Join(cacheRoot, "tombstones") - - if err := os.MkdirAll(pathsDir, 0o755); err != nil { - return err - } - if err := os.MkdirAll(oidsDir, 0o755); err != nil { - return err - } - _ = os.MkdirAll(tombsDir, 0o755) // optional - - changes, err := stagedChanges(ctx) - if err != nil { - return err - } - if len(changes) == 0 { - return nil - } - - now := time.Now().UTC().Format(time.RFC3339) - - // Process renames first so subsequent add/modify logic sees the "new" path. - // This mirrors how we want cache paths to follow staged paths. - for _, ch := range changes { - if ch.Kind != KindRename { - continue - } - // Only act if BOTH old and new are LFS in scope? Prefer: - // - If the new path is LFS, we migrate. - // - If it isn't LFS, we remove old path entry (out of scope). - newOID, newIsLFS, err := stagedLFSOID(ctx, ch.NewPath) - if err != nil { - // If file doesn't exist in index due to weird staging, skip. - continue - } - - oldPathFile := pathEntryFile(pathsDir, ch.OldPath) - newPathFile := pathEntryFile(pathsDir, ch.NewPath) - - if newIsLFS { - // Move/overwrite path entry file - if err := moveFileBestEffort(oldPathFile, newPathFile); err != nil && !errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("rename migrate path entry: %w", err) - } - - // Ensure path entry content correct - if err := writeJSONAtomic(newPathFile, PathEntry{ - Path: ch.NewPath, - LFSOID: newOID, - UpdatedAt: now, - }); err != nil { - return err - } - - // Update oid entry: replace old path with new path for that OID - if err := oidAddOrReplacePath(oidsDir, newOID, ch.OldPath, ch.NewPath, now, false); err != nil { - return err - } - } else { - // Out of scope now: remove any cached path entry. - _ = os.Remove(oldPathFile) - } - } - - // Process adds/modifies/deletes (and renames again just to ensure content correctness on new path). - for _, ch := range changes { - switch ch.Kind { - case KindAdd, KindModify: - if err := handleUpsert(ctx, pathsDir, oidsDir, ch.NewPath, now); err != nil { - return err - } - case KindRename: - // Treat like upsert on NewPath to ensure OID/path consistency if content also changed. - if err := handleUpsert(ctx, pathsDir, oidsDir, ch.NewPath, now); err != nil { - return err - } - // Optionally also remove old path from *other* OID entry if rename+content-change changed OID. - // We'll do it inside handleUpsert by checking previous cached OID for that path (after move). - case KindDelete: - if err := handleDelete(ctx, pathsDir, oidsDir, tombsDir, ch.NewPath, now); err != nil { - return err - } - } - } - - return nil -} - -func handleUpsert(ctx context.Context, pathsDir, oidsDir, path, now string) error { - oid, isLFS, err := stagedLFSOID(ctx, path) - if err != nil { - // If file isn't in index, ignore. - return nil - } - if !isLFS { - // Out of scope. - return nil - } - - pathFile := pathEntryFile(pathsDir, path) - - // Load previous path entry if it exists to detect content changes. - var prev PathEntry - prevExists := false - if b, err := os.ReadFile(pathFile); err == nil { - _ = json.Unmarshal(b, &prev) - if prev.Path != "" && prev.LFSOID != "" { - prevExists = true - } - } - - // Write/update path entry. - if err := writeJSONAtomic(pathFile, PathEntry{ - Path: path, - LFSOID: oid, - UpdatedAt: now, - }); err != nil { - return err - } - - // Update OID entry for new oid: add path. - contentChanged := prevExists && prev.LFSOID != oid - if err := oidAddOrReplacePath(oidsDir, oid, "", path, now, contentChanged); err != nil { - return err - } - - // If content changed, remove path from the *old* oid entry (best effort). - if contentChanged { - _ = oidRemovePath(oidsDir, prev.LFSOID, path, now) - } - - return nil -} - -func handleDelete(ctx context.Context, pathsDir, oidsDir, tombsDir, path, now string) error { - // Only consider deletion if it was previously an LFS entry (cache-driven). - pathFile := pathEntryFile(pathsDir, path) - b, err := os.ReadFile(pathFile) - if err != nil { - // nothing to do - return nil - } - var pe PathEntry - if err := json.Unmarshal(b, &pe); err != nil { - // corrupted cache; remove it - _ = os.Remove(pathFile) - return nil - } - // Remove path entry. - _ = os.Remove(pathFile) - - // Remove this path from the old oid entry (best effort). - if pe.LFSOID != "" { - _ = oidRemovePath(oidsDir, pe.LFSOID, path, now) - } - - // Optional tombstone. - tombFile := filepath.Join(tombsDir, encodePath(path)+".json") - _ = writeJSONAtomic(tombFile, map[string]string{ - "path": path, - "deleted_at": now, - }) - - return nil -} - -// stagedChanges parses: git diff --cached --name-status -M -// Formats: -// -// Apath -// Mpath -// Dpath -// R100oldnew -func stagedChanges(ctx context.Context) ([]Change, error) { - out, err := git(ctx, "diff", "--cached", "--name-status", "-M") - if err != nil { - return nil, err - } - var changes []Change - sc := bufio.NewScanner(bytes.NewReader(out)) - for sc.Scan() { - line := sc.Text() - if strings.TrimSpace(line) == "" { - continue - } - parts := strings.Split(line, "\t") - if len(parts) < 2 { - continue - } - status := parts[0] - switch { - case status == "A": - changes = append(changes, Change{Kind: KindAdd, NewPath: parts[1], Status: status}) - case status == "M": - changes = append(changes, Change{Kind: KindModify, NewPath: parts[1], Status: status}) - case status == "D": - changes = append(changes, Change{Kind: KindDelete, NewPath: parts[1], Status: status}) - case strings.HasPrefix(status, "R") && len(parts) >= 3: - changes = append(changes, Change{Kind: KindRename, OldPath: parts[1], NewPath: parts[2], Status: status}) - default: - // ignore other statuses (C, T, U, etc) for this reference impl - } - } - if err := sc.Err(); err != nil { - return nil, err - } - return changes, nil -} - -// stagedLFSOID returns (oid, isLFS, err) based on STAGED content. -// isLFS is true only if the staged file is a valid LFS pointer with an oid sha256 line. -func stagedLFSOID(ctx context.Context, path string) (string, bool, error) { - out, err := git(ctx, "show", ":"+path) - if err != nil { - // path may not exist in index (deleted/intent-to-add weirdness) - return "", false, err - } - - // Fast parse: look for spec line and oid line near top. - // LFS pointer files are small; scanning full content is fine. - var hasSpec bool - var oid string - - sc := bufio.NewScanner(bytes.NewReader(out)) - for sc.Scan() { - line := sc.Text() - if line == lfsSpecLine { - hasSpec = true - continue - } - if strings.HasPrefix(line, "oid sha256:") { - hex := strings.TrimPrefix(line, "oid sha256:") - hex = strings.TrimSpace(hex) - if hex != "" { - oid = "sha256:" + hex - } - // keep scanning a bit in case spec is below (rare), but we can break once both are found. - } - // pointer usually has only a few lines; stop early after 10 lines - if hasSpec && oid != "" { - break - } - } - if err := sc.Err(); err != nil { - return "", false, err - } - - if hasSpec && oid != "" { - return oid, true, nil - } - return "", false, nil -} - -func gitRevParseGitDir(ctx context.Context) (string, error) { - out, err := git(ctx, "rev-parse", "--git-dir") - if err != nil { - return "", err - } - gitDir := strings.TrimSpace(string(out)) - if gitDir == "" { - return "", errors.New("could not determine .git dir") - } - // If gitDir is relative, resolve relative to repo root - if !filepath.IsAbs(gitDir) { - rootOut, err := git(ctx, "rev-parse", "--show-toplevel") - if err != nil { - return "", err - } - root := strings.TrimSpace(string(rootOut)) - gitDir = filepath.Join(root, gitDir) - } - return gitDir, nil -} - -func git(ctx context.Context, args ...string) ([]byte, error) { - cmd := exec.CommandContext(ctx, "git", args...) - cmd.Env = os.Environ() - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - // include stderr for debugging; don’t leak massive output - msg := strings.TrimSpace(stderr.String()) - if msg == "" { - msg = err.Error() - } - return nil, fmt.Errorf("git %s: %s", strings.Join(args, " "), msg) - } - return stdout.Bytes(), nil -} - -// pathEntryFile maps a repo-relative path to a cache file location. -// We keep a deterministic encoding so any path maps to exactly one file. -func pathEntryFile(pathsDir, path string) string { - return filepath.Join(pathsDir, encodePath(path)+".json") -} - -func encodePath(path string) string { - // base64url encoding of the UTF-8 path string (no padding) is simple and safe. - return base64.RawURLEncoding.EncodeToString([]byte(path)) -} - -func oidEntryFile(oidsDir, oid string) string { - // OID contains ":"; make it filesystem safe but still human readable. - // Use a stable transform; here: sha256 of oid string to avoid path length issues. - sum := sha256.Sum256([]byte(oid)) - return filepath.Join(oidsDir, fmt.Sprintf("%x.json", sum[:])) -} - -// oidAddOrReplacePath: -// - loads oid entry (if exists) -// - adds newPath to paths[] -// - if oldPath != "" and present, replaces it with newPath -// - sets ContentChange flag if requested (ORed into existing flag) -// - preserves existing s3_url hint -func oidAddOrReplacePath(oidsDir, oid, oldPath, newPath, now string, contentChanged bool) error { - f := oidEntryFile(oidsDir, oid) - - entry := OIDEntry{ - LFSOID: oid, - Paths: []string{}, - UpdatedAt: now, - } - if b, err := os.ReadFile(f); err == nil { - _ = json.Unmarshal(b, &entry) - // ensure oid is set even if old file was incomplete - entry.LFSOID = oid - } - - paths := make(map[string]struct{}, len(entry.Paths)+1) - for _, p := range entry.Paths { - paths[p] = struct{}{} - } - - if oldPath != "" { - delete(paths, oldPath) - } - if newPath != "" { - paths[newPath] = struct{}{} - } - - entry.Paths = keysSorted(paths) - entry.UpdatedAt = now - entry.ContentChange = entry.ContentChange || contentChanged - - return writeJSONAtomic(f, entry) -} - -func oidRemovePath(oidsDir, oid, path, now string) error { - f := oidEntryFile(oidsDir, oid) - - b, err := os.ReadFile(f) - if err != nil { - return err - } - var entry OIDEntry - if err := json.Unmarshal(b, &entry); err != nil { - return err - } - paths := make(map[string]struct{}, len(entry.Paths)) - for _, p := range entry.Paths { - if p == path { - continue - } - paths[p] = struct{}{} - } - entry.Paths = keysSorted(paths) - entry.UpdatedAt = now - - // If no paths remain, keep the file (it may still hold s3_url hint) or delete it. - // This ADR allows stale entries; keeping is fine. Optionally delete when empty: - // if len(entry.Paths) == 0 && entry.S3URL == "" { return os.Remove(f) } - - return writeJSONAtomic(f, entry) -} - -func keysSorted(m map[string]struct{}) []string { - out := make([]string, 0, len(m)) - for k := range m { - out = append(out, k) - } - sort.Strings(out) - return out -} - -// writeJSONAtomic writes JSON to a temp file then renames it into place. -// This avoids partially written cache files if the process is interrupted. -func writeJSONAtomic(path string, v any) error { - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err - } - - tmp := path + ".tmp" - f, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) - if err != nil { - return err - } - defer func() { _ = f.Close() }() - - enc := json.NewEncoder(f) - enc.SetIndent("", " ") - if err := enc.Encode(v); err != nil { - _ = os.Remove(tmp) - return err - } - if err := f.Sync(); err != nil { - _ = os.Remove(tmp) - return err - } - if err := f.Close(); err != nil { - _ = os.Remove(tmp) - return err - } - return os.Rename(tmp, path) -} - -func moveFileBestEffort(src, dst string) error { - // Ensure destination directory exists. - if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { - return err - } - // Rename will fail across devices; fall back to copy+remove. - if err := os.Rename(src, dst); err == nil { - return nil - } else if errors.Is(err, os.ErrNotExist) { - return err - } - - in, err := os.Open(src) - if err != nil { - return err - } - defer in.Close() - - out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) - if err != nil { - return err - } - - if _, err := io.Copy(out, in); err != nil { - _ = out.Close() - return err - } - if err := out.Close(); err != nil { - return err - } - return os.Remove(src) -} diff --git a/cmd/precommit/main_test.go b/cmd/precommit/main_test.go index 8a0fb0c6..909b8566 100644 --- a/cmd/precommit/main_test.go +++ b/cmd/precommit/main_test.go @@ -9,6 +9,8 @@ import ( "strings" "testing" "time" + + "github.com/calypr/git-drs/internal/precommit_cache" ) func TestHandleUpsertIgnoresNonLFSFile(t *testing.T) { @@ -35,12 +37,13 @@ func TestHandleUpsertIgnoresNonLFSFile(t *testing.T) { t.Fatalf("mkdir oids: %v", err) } + cache := &precommit_cache.Cache{Root: cacheRoot, PathsDir: pathsDir, OIDsDir: oidsDir} now := time.Now().UTC().Format(time.RFC3339) - if err := handleUpsert(context.Background(), pathsDir, oidsDir, "data/file.txt", now); err != nil { + if err := handleUpsert(context.Background(), cache, "data/file.txt", now); err != nil { t.Fatalf("handleUpsert: %v", err) } - pathEntry := pathEntryFile(pathsDir, "data/file.txt") + pathEntry := precommit_cache.PathEntryPath(cache, "data/file.txt") if _, err := os.Stat(pathEntry); !os.IsNotExist(err) { t.Fatalf("expected no cache entry for non-LFS file, got err=%v", err) } @@ -76,17 +79,18 @@ func TestHandleUpsertWritesLFSPointerCache(t *testing.T) { t.Fatalf("mkdir oids: %v", err) } + cache := &precommit_cache.Cache{Root: cacheRoot, PathsDir: pathsDir, OIDsDir: oidsDir} now := time.Now().UTC().Format(time.RFC3339) - if err := handleUpsert(context.Background(), pathsDir, oidsDir, "data/file.bin", now); err != nil { + if err := handleUpsert(context.Background(), cache, "data/file.bin", now); err != nil { t.Fatalf("handleUpsert: %v", err) } - pathEntry := pathEntryFile(pathsDir, "data/file.bin") + pathEntry := precommit_cache.PathEntryPath(cache, "data/file.bin") pathData, err := os.ReadFile(pathEntry) if err != nil { t.Fatalf("read path entry: %v", err) } - var pathCache PathEntry + var pathCache precommit_cache.PathEntry if err := json.Unmarshal(pathData, &pathCache); err != nil { t.Fatalf("unmarshal path entry: %v", err) } @@ -97,12 +101,12 @@ func TestHandleUpsertWritesLFSPointerCache(t *testing.T) { t.Fatalf("expected lfs oid sha256:deadbeef, got %q", pathCache.LFSOID) } - oidEntry := oidEntryFile(oidsDir, "sha256:deadbeef") + oidEntry := precommit_cache.OIDEntryPath(cache, "sha256:deadbeef") oidData, err := os.ReadFile(oidEntry) if err != nil { t.Fatalf("read oid entry: %v", err) } - var oidCache OIDEntry + var oidCache precommit_cache.OIDEntry if err := json.Unmarshal(oidData, &oidCache); err != nil { t.Fatalf("unmarshal oid entry: %v", err) } @@ -114,6 +118,85 @@ func TestHandleUpsertWritesLFSPointerCache(t *testing.T) { } } +func TestCollectOversizedPlainGitStagedFiles(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + plainPath := filepath.Join(repo, "data", "large.bin") + if err := os.MkdirAll(filepath.Dir(plainPath), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(plainPath, []byte("plain oversized payload"), 0o644); err != nil { + t.Fatalf("write plain file: %v", err) + } + gitCmd(t, repo, "add", "data/large.bin") + + pointerPath := filepath.Join(repo, "data", "pointer.bin") + lfsPointer := strings.Join([]string{ + "version https://git-lfs.github.com/spec/v1", + "oid sha256:deadbeef", + "size 999", + "", + }, "\n") + if err := os.WriteFile(pointerPath, []byte(lfsPointer), 0o644); err != nil { + t.Fatalf("write pointer file: %v", err) + } + gitCmd(t, repo, "add", "data/pointer.bin") + + changes, err := stagedChanges(context.Background()) + if err != nil { + t.Fatalf("stagedChanges: %v", err) + } + files, err := collectOversizedPlainGitStagedFiles(context.Background(), changes, 1) + if err != nil { + t.Fatalf("collectOversizedPlainGitStagedFiles: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 oversized plain file, got %d: %+v", len(files), files) + } + if files[0].Path != "data/large.bin" { + t.Fatalf("unexpected oversized file path: %+v", files[0]) + } +} + +func TestRunAbortsWhenOversizedPlainGitCommitIsRejected(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + path := filepath.Join(repo, "data", "large.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("plain oversized payload"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + gitCmd(t, repo, "add", "data/large.bin") + + oldThreshold := directCommitWarningThresholdBytes + oldPrompt := confirmOversizedDirectGitCommit + t.Cleanup(func() { + directCommitWarningThresholdBytes = oldThreshold + confirmOversizedDirectGitCommit = oldPrompt + }) + directCommitWarningThresholdBytes = 1 + confirmOversizedDirectGitCommit = func(files []OversizedStagedFile) (bool, error) { + if len(files) != 1 || files[0].Path != "data/large.bin" { + t.Fatalf("unexpected prompt files: %+v", files) + } + return false, nil + } + + err := run(context.Background()) + if err == nil { + t.Fatal("expected run to abort when oversized file warning is rejected") + } + if !strings.Contains(err.Error(), "commit aborted") { + t.Fatalf("unexpected error: %v", err) + } +} + func setupGitRepo(t *testing.T) string { t.Helper() dir := t.TempDir() diff --git a/cmd/precommit/run.go b/cmd/precommit/run.go new file mode 100644 index 00000000..1dcad053 --- /dev/null +++ b/cmd/precommit/run.go @@ -0,0 +1,124 @@ +// Package precommit updates the local `.git/drs/pre-commit` cache from staged +// pointer changes. The cache is rebuildable local bookkeeping, distinct from +// the authoritative local DRS metadata stored under `.git/drs/lfs/objects`. +package precommit + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/calypr/git-drs/internal/precommit_cache" +) + +type ChangeKind int + +const ( + KindAdd ChangeKind = iota + KindModify + KindDelete + KindRename +) + +type Change struct { + Kind ChangeKind + OldPath string + NewPath string + Status string +} + +func run(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + + cache, err := precommit_cache.Open(ctx) + if err != nil { + return err + } + if err := precommit_cache.EnsureLayout(cache); err != nil { + return err + } + tombsDir := filepath.Join(cache.Root, "tombstones") + if err := os.MkdirAll(tombsDir, 0o755); err != nil { + return fmt.Errorf("create tombstones directory: %w", err) + } + + changes, err := stagedChanges(ctx) + if err != nil { + return err + } + if len(changes) == 0 { + return nil + } + oversized, err := collectOversizedPlainGitStagedFiles(ctx, changes, directCommitWarningThresholdBytes) + if err != nil { + return err + } + if len(oversized) > 0 { + allowed, err := confirmOversizedDirectGitCommit(oversized) + if err != nil { + return err + } + if !allowed { + return fmt.Errorf("commit aborted so you can track large files before committing them directly to Git") + } + } + + now := time.Now().UTC().Format(time.RFC3339) + for _, ch := range changes { + if ch.Kind != KindRename { + continue + } + newOID, newIsLFS, err := stagedLFSOID(ctx, ch.NewPath) + if err != nil { + continue + } + + oldPathFile := precommit_cache.PathEntryPath(cache, ch.OldPath) + newPathFile := precommit_cache.PathEntryPath(cache, ch.NewPath) + + if newIsLFS { + if err := moveFileBestEffort(oldPathFile, newPathFile); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("rename migrate path entry: %w", err) + } + + if err := precommit_cache.WritePathEntry(cache, precommit_cache.PathEntry{ + Path: ch.NewPath, + LFSOID: newOID, + UpdatedAt: now, + }); err != nil { + return err + } + + if err := precommit_cache.UpsertOIDPath(cache, newOID, ch.OldPath, ch.NewPath, "", now, false); err != nil { + return err + } + } else { + if err := removeIfExists(oldPathFile); err != nil { + return fmt.Errorf("remove stale path entry for %s: %w", ch.OldPath, err) + } + } + } + + for _, ch := range changes { + switch ch.Kind { + case KindAdd, KindModify: + if err := handleUpsert(ctx, cache, ch.NewPath, now); err != nil { + return err + } + case KindRename: + if err := handleUpsert(ctx, cache, ch.NewPath, now); err != nil { + return err + } + case KindDelete: + if err := handleDelete(ctx, cache, tombsDir, ch.NewPath, now); err != nil { + return err + } + } + } + + return nil +} diff --git a/cmd/precommit/warnings.go b/cmd/precommit/warnings.go new file mode 100644 index 00000000..1088132f --- /dev/null +++ b/cmd/precommit/warnings.go @@ -0,0 +1,91 @@ +package precommit + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "os" + "sort" + "strings" +) + +type OversizedStagedFile struct { + Path string + Size int64 +} + +func collectOversizedPlainGitStagedFiles(ctx context.Context, changes []Change, thresholdBytes int64) ([]OversizedStagedFile, error) { + if thresholdBytes <= 0 { + return nil, nil + } + var oversized []OversizedStagedFile + seen := make(map[string]struct{}) + for _, ch := range changes { + if ch.Kind != KindAdd && ch.Kind != KindModify && ch.Kind != KindRename { + continue + } + path := ch.NewPath + if path == "" { + continue + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + + _, isLFS, err := stagedLFSOID(ctx, path) + if err != nil { + continue + } + if isLFS { + continue + } + + size, err := stagedBlobSize(ctx, path) + if err != nil { + return nil, err + } + if size <= thresholdBytes { + continue + } + oversized = append(oversized, OversizedStagedFile{Path: path, Size: size}) + } + sort.Slice(oversized, func(i, j int) bool { return oversized[i].Path < oversized[j].Path }) + return oversized, nil +} + +func promptOversizedDirectGitCommit(files []OversizedStagedFile) (bool, error) { + if len(files) == 0 { + return true, nil + } + + fmt.Fprintf(os.Stderr, "\nWarning: the following staged files are being committed directly to Git and exceed %s:\n\n", humanBytes(directCommitWarningThresholdBytes)) + for _, f := range files { + fmt.Fprintf(os.Stderr, " - %s (%s)\n", f.Path, humanBytes(f.Size)) + } + fmt.Fprintln(os.Stderr, "\nIf these should be managed by git-drs, track them first and re-add them.") + fmt.Fprint(os.Stderr, "Continue committing these files directly to GitHub? [y/N]: ") + + reader := bufio.NewReader(os.Stdin) + line, err := reader.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return false, err + } + answer := strings.ToLower(strings.TrimSpace(line)) + return answer == "y" || answer == "yes", nil +} + +func humanBytes(n int64) string { + const unit = int64(1024) + if n < unit { + return fmt.Sprintf("%d B", n) + } + div, exp := unit, 0 + for q := n / unit; q >= unit; q /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(n)/float64(div), "KMGTPE"[exp]) +} diff --git a/cmd/prepush/main.go b/cmd/prepush/main.go deleted file mode 100644 index 9e2e841e..00000000 --- a/cmd/prepush/main.go +++ /dev/null @@ -1,605 +0,0 @@ -package prepush - -import ( - "bufio" - "bytes" - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "log/slog" - "net/http" - "os" - "os/exec" - "sort" - "strings" - "time" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsmap" - "github.com/calypr/git-drs/internal/drsobject" - "github.com/calypr/git-drs/internal/gitrepo" - "github.com/calypr/git-drs/internal/lfs" - "github.com/calypr/git-drs/internal/precommit_cache" - drsapi "github.com/calypr/syfon/apigen/client/drs" - syfoncommon "github.com/calypr/syfon/common" - "github.com/spf13/cobra" -) - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "pre-push-prepare", - Short: "pre-push hook to update DRS objects", - Long: "Pre-push hook that updates DRS objects before transfer", - Args: cobra.RangeArgs(0, 2), - RunE: func(cmd *cobra.Command, args []string) error { - return NewPrePushService().Run(args, os.Stdin) - }, -} - -type PrePushService struct { - newLogger func(string, bool) (*slog.Logger, error) - loadConfig func() (*config.Config, error) - writeDrsObjects func(drsobject.Builder, map[string]lfs.LfsFileInfo, drsmap.WriteOptions) error - createTempFile func(dir, pattern string) (*os.File, error) -} - -func NewPrePushService() *PrePushService { - return &PrePushService{ - newLogger: drslog.NewLogger, - loadConfig: config.LoadConfig, - writeDrsObjects: drsmap.WriteObjectsForLFSFiles, - createTempFile: os.CreateTemp, - } -} - -func (s *PrePushService) Run(args []string, stdin io.Reader) error { - ctx := context.Background() - myLogger, err := s.newLogger("", false) - if err != nil { - return fmt.Errorf("error creating logger: %v", err) - } - - myLogger.Info("~~~~~~~~~~~~~ START: pre-push ~~~~~~~~~~~~~") - - cfg, err := s.loadConfig() - if err != nil { - return fmt.Errorf("error getting config: %v", err) - } - - gitRemoteName, gitRemoteLocation := parseRemoteArgs(args) - myLogger.Debug(fmt.Sprintf("git remote name: %s, git remote location: %s", gitRemoteName, gitRemoteLocation)) - - remote, err := cfg.GetDefaultRemote() - if err != nil { - myLogger.Debug(fmt.Sprintf("Warning. Error getting default remote: %v", err)) - fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting default remote:", err) - return nil - } - - remoteConfig := cfg.GetRemote(remote) - if remoteConfig == nil { - fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting remote configuration.") - myLogger.Debug("Warning. Skipping DRS preparation. Error getting remote configuration.") - return nil - } - - scope, err := gitrepo.ResolveBucketScope( - remoteConfig.GetOrganization(), - remoteConfig.GetProjectId(), - remoteConfig.GetBucketName(), - remoteConfig.GetStoragePrefix(), - ) - if err != nil { - return err - } - - builder := drsobject.NewBuilder(scope.Bucket, remoteConfig.GetProjectId()) - builder.Organization = remoteConfig.GetOrganization() - builder.StoragePrefix = scope.Prefix - myLogger.Debug(fmt.Sprintf("Current server project: %s (org: %s)", builder.Project, builder.Organization)) - - tmp, err := bufferStdin(stdin, s.createTempFile) - if err != nil { - myLogger.Error(fmt.Sprintf("error buffering stdin: %v", err)) - return err - } - defer func() { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) - }() - - refs, err := readPushedRefs(tmp) - if err != nil { - myLogger.Error(fmt.Sprintf("error reading pushed refs: %v", err)) - return err - } - branches := branchesFromRefs(refs) - - cache, cacheReady := openCache(ctx, myLogger) - lfsFiles, usedCache, err := collectLfsFiles(ctx, cache, cacheReady, gitRemoteName, gitRemoteLocation, branches, refs, myLogger) - if err != nil { - myLogger.Error(fmt.Sprintf("error collecting LFS files: %v", err)) - return err - } - - myLogger.Debug(fmt.Sprintf("Preparing DRS objects for push branches: %v (cache=%v)", branches, usedCache)) - err = s.writeDrsObjects(builder, lfsFiles, drsmap.WriteOptions{ - Cache: cache, - PreferCacheURL: usedCache, - Logger: myLogger, - }) - if err != nil { - myLogger.Error(fmt.Sprintf("WriteObjectsForLFSFiles failed: %v", err)) - return err - } - - // Stage metadata in one packet; server consumes it at LFS verify-time. - myLogger.Info(fmt.Sprintf("Staging %d DRS metadata records for LFS verify", len(lfsFiles))) - if err := submitPendingLFSMeta(ctx, remote, remoteConfig.GetEndpoint(), lfsFiles, myLogger); err != nil { - myLogger.Error(fmt.Sprintf("DRS metadata staging failed: %v", err)) - return fmt.Errorf("DRS metadata staging failed: %w", err) - } - - myLogger.Info("~~~~~~~~~~~~~ COMPLETED: pre-push ~~~~~~~~~~~~~") - return nil -} - -type metadataSubmitRequest struct { - Candidates []metadataCandidate `json:"candidates"` - TTLSeconds int64 `json:"ttl_seconds,omitempty"` -} - -type metadataChecksum struct { - Type string `json:"type"` - Checksum string `json:"checksum"` -} - -type metadataAccessURL struct { - URL string `json:"url,omitempty"` -} - -type metadataAccessMethod struct { - Type string `json:"type,omitempty"` - AccessURL metadataAccessURL `json:"access_url,omitempty"` - AccessID string `json:"access_id,omitempty"` - Region string `json:"region,omitempty"` - Authorizations map[string][]string `json:"authorizations,omitempty"` -} - -type metadataCandidate struct { - Id string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Size int64 `json:"size"` - Version string `json:"version,omitempty"` - MimeType string `json:"mime_type,omitempty"` - Checksums []metadataChecksum `json:"checksums"` - AccessMethods []metadataAccessMethod `json:"access_methods,omitempty"` - Description string `json:"description,omitempty"` - Aliases []string `json:"aliases,omitempty"` -} - -func toMetadataCandidate(c drsapi.DrsObjectCandidate) metadataCandidate { - name := "" - if c.Name != nil { - name = *c.Name - } - mimeType := "" - if c.MimeType != nil { - mimeType = *c.MimeType - } - description := "" - if c.Description != nil { - description = *c.Description - } - aliases := []string(nil) - if c.Aliases != nil { - aliases = append([]string(nil), (*c.Aliases)...) - } - out := metadataCandidate{ - Name: name, - Size: c.Size, - Version: "", - MimeType: mimeType, - Description: description, - Aliases: aliases, - } - if c.Version != nil { - out.Version = *c.Version - } - - if len(c.Checksums) > 0 { - out.Checksums = make([]metadataChecksum, 0, len(c.Checksums)) - for _, cs := range c.Checksums { - out.Checksums = append(out.Checksums, metadataChecksum{ - Type: string(cs.Type), - Checksum: cs.Checksum, - }) - } - } - - if c.AccessMethods != nil && len(*c.AccessMethods) > 0 { - out.AccessMethods = make([]metadataAccessMethod, 0, len(*c.AccessMethods)) - for _, am := range *c.AccessMethods { - accID := "" - if am.AccessId != nil { - accID = *am.AccessId - } - region := "" - if am.Region != nil { - region = *am.Region - } - accURL := "" - if am.AccessUrl != nil { - accURL = am.AccessUrl.Url - } - m := metadataAccessMethod{ - Type: string(am.Type), - AccessID: accID, - Region: region, - AccessURL: metadataAccessURL{ - URL: accURL, - }, - } - if authzMap := syfoncommon.AuthzMapFromAccessMethodAuthorizations(am.Authorizations); len(authzMap) > 0 { - m.Authorizations = authzMap - } - out.AccessMethods = append(out.AccessMethods, m) - } - } - - return out -} - -func submitPendingLFSMeta(ctx context.Context, remote config.Remote, endpoint string, lfsFiles map[string]lfs.LfsFileInfo, logger *slog.Logger) error { - base := strings.TrimRight(strings.TrimSpace(endpoint), "/") - if base == "" { - return fmt.Errorf("remote endpoint is empty") - } - url := base + "/info/lfs/objects/metadata" - - candidates := make([]metadataCandidate, 0, len(lfsFiles)) - for _, file := range lfsFiles { - obj, err := drsobject.ReadObject(common.DRS_OBJS_PATH, file.Oid) - if err != nil || obj == nil { - logger.Debug(fmt.Sprintf("skipping oid %s: local DRS object not found", file.Oid)) - continue - } - candidates = append(candidates, toMetadataCandidate(drsobject.ConvertToCandidate(obj))) - } - if len(candidates) == 0 { - logger.Debug("no metadata candidates to stage") - return nil - } - - reqBody := metadataSubmitRequest{ - Candidates: candidates, - TTLSeconds: int64((20 * time.Minute).Seconds()), - } - payload, err := json.Marshal(reqBody) - if err != nil { - return fmt.Errorf("failed to encode pending metadata request: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) - if err != nil { - return fmt.Errorf("failed to create pending metadata request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", "application/vnd.git-lfs+json") - if authHeader, ok := resolveRemoteAuthHeader(string(remote)); ok { - httpReq.Header.Set("Authorization", authHeader) - } - - client := pendingMetadataClientFactory() - resp, err := client.Do(httpReq) - if err != nil { - return fmt.Errorf("pending metadata request failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - bodyText := strings.TrimSpace(string(body)) - // Some deployments do not yet expose /info/lfs/objects/metadata. - // Treat this as optional capability and continue with push flow. - switch resp.StatusCode { - case http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented: - logger.Warn(fmt.Sprintf("metadata staging endpoint unavailable (status=%d); continuing without staged metadata", resp.StatusCode)) - return nil - } - // Some reverse proxies/frontends may return HTML 404/maintenance pages with 5xx. - // If this looks like non-API HTML and not a structured LFS error, degrade gracefully. - if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/html") { - logger.Warn(fmt.Sprintf("metadata staging returned HTML response (status=%d); continuing without staged metadata", resp.StatusCode)) - return nil - } - return fmt.Errorf("pending metadata request failed: status=%d body=%s", resp.StatusCode, bodyText) - } - return nil -} - -func resolveRemoteAuthHeader(remoteName string) (string, bool) { - if token, err := gitrepo.GetRemoteToken(remoteName); err == nil { - if token = strings.TrimSpace(token); token != "" { - return "Bearer " + token, true - } - } - username, password, err := gitrepo.GetRemoteBasicAuth(remoteName) - if err != nil || username == "" || password == "" { - return "", false - } - creds := username + ":" + password - return "Basic " + base64.StdEncoding.EncodeToString([]byte(creds)), true -} - -func parseRemoteArgs(args []string) (string, string) { - var gitRemoteName, gitRemoteLocation string - if len(args) >= 1 { - gitRemoteName = args[0] - } - if len(args) >= 2 { - gitRemoteLocation = args[1] - } - if gitRemoteName == "" { - gitRemoteName = "origin" - } - return gitRemoteName, gitRemoteLocation -} - -type pushedRef struct { - LocalRef string - LocalSHA string - RemoteRef string - RemoteSHA string -} - -func bufferStdin(stdin io.Reader, createTempFile func(dir, pattern string) (*os.File, error)) (*os.File, error) { - tmp, err := createTempFile("", "prepush-stdin-*") - if err != nil { - return nil, fmt.Errorf("error creating temp file for stdin: %w", err) - } - - if _, err := io.Copy(tmp, stdin); err != nil { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) - return nil, fmt.Errorf("error buffering stdin: %w", err) - } - - if _, err := tmp.Seek(0, 0); err != nil { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) - return nil, fmt.Errorf("error seeking temp stdin: %w", err) - } - return tmp, nil -} - -// readPushedBranches reads git push lines from the provided temp file, -// extracts unique local branch names for refs under `refs/heads/` and -// returns them sorted. The file is rewound to the start before returning. -func readPushedRefs(f io.ReadSeeker) ([]pushedRef, error) { - // Ensure we read from start - // example: - // refs/heads/main 67890abcdef1234567890abcdef1234567890abcd refs/heads/main 12345abcdef67890abcdef1234567890abcdef12 - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - scanner := bufio.NewScanner(f) - refs := make([]pushedRef, 0) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) < 4 { - continue - } - refs = append(refs, pushedRef{ - LocalRef: fields[0], - LocalSHA: fields[1], - RemoteRef: fields[2], - RemoteSHA: fields[3], - }) - } - if err := scanner.Err(); err != nil { - return nil, err - } - // Rewind so caller can reuse the file - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - return refs, nil -} - -func branchesFromRefs(refs []pushedRef) []string { - const prefix = "refs/heads/" - set := make(map[string]struct{}) - for _, ref := range refs { - if strings.HasPrefix(ref.LocalRef, prefix) { - branch := strings.TrimPrefix(ref.LocalRef, prefix) - if branch != "" { - set[branch] = struct{}{} - } - } - } - branches := make([]string, 0, len(set)) - for b := range set { - branches = append(branches, b) - } - sort.Strings(branches) - return branches -} - -func openCache(ctx context.Context, logger *slog.Logger) (*precommit_cache.Cache, bool) { - cache, err := precommit_cache.Open(ctx) - if err != nil { - logger.Debug(fmt.Sprintf("pre-commit cache unavailable: %v", err)) - return nil, false - } - if _, err := os.Stat(cache.Root); err != nil { - if os.IsNotExist(err) { - logger.Debug("pre-commit cache missing; continuing without cache") - } else { - logger.Debug(fmt.Sprintf("pre-commit cache access error: %v", err)) - } - return nil, false - } - return cache, true -} - -func collectLfsFiles(ctx context.Context, cache *precommit_cache.Cache, cacheReady bool, gitRemoteName, gitRemoteLocation string, branches []string, refs []pushedRef, logger *slog.Logger) (map[string]lfs.LfsFileInfo, bool, error) { - if cacheReady { - lfsFiles, ok, err := lfsFilesFromCache(ctx, cache, refs, logger) - if err != nil { - logger.Debug(fmt.Sprintf("pre-commit cache read failed: %v", err)) - } else if ok { - return lfsFiles, true, nil - } - logger.Debug("pre-commit cache incomplete or stale; falling back to LFS discovery") - } - lfsFiles, err := lfs.GetAllLfsFiles(gitRemoteName, gitRemoteLocation, branches, logger) - if err != nil { - return nil, false, err - } - return lfsFiles, false, nil -} - -const cacheMaxAge = 24 * time.Hour - -var pendingMetadataClientFactory = func() *http.Client { - return &http.Client{Timeout: 20 * time.Second} -} - -func normalizeCachedOID(oid string) string { - normalized := strings.TrimSpace(oid) - if len(normalized) >= len("sha256:") && strings.EqualFold(normalized[:len("sha256:")], "sha256:") { - normalized = normalized[len("sha256:"):] - } - return strings.TrimSpace(normalized) -} - -func lfsFilesFromCache(ctx context.Context, cache *precommit_cache.Cache, refs []pushedRef, logger *slog.Logger) (map[string]lfs.LfsFileInfo, bool, error) { - if cache == nil { - return nil, false, nil - } - paths, err := listPushedPaths(ctx, refs) - if err != nil { - return nil, false, err - } - lfsFiles := make(map[string]lfs.LfsFileInfo, len(paths)) - for _, path := range paths { - entry, ok, err := cache.ReadPathEntry(path) - if err != nil { - return nil, false, err - } - if !ok { - return nil, false, nil - } - oid := normalizeCachedOID(entry.LFSOID) - if oid == "" { - return nil, false, nil - } - if entry.UpdatedAt == "" || precommit_cache.StaleAfter(entry.UpdatedAt, cacheMaxAge) { - return nil, false, nil - } - stat, err := os.Stat(path) - if err != nil { - logger.Debug(fmt.Sprintf("cache path stat failed for %s: %v", path, err)) - return nil, false, nil - } - lfsFiles[path] = lfs.LfsFileInfo{ - Name: path, - Size: stat.Size(), - OidType: "sha256", - Oid: oid, - Version: "https://git-lfs.github.com/spec/v1", - } - } - return lfsFiles, true, nil -} - -func listPushedPaths(ctx context.Context, refs []pushedRef) ([]string, error) { - const zeroSHA = "0000000000000000000000000000000000000000" - set := make(map[string]struct{}) - for _, ref := range refs { - if ref.LocalSHA == "" || ref.LocalSHA == zeroSHA { - continue - } - var args []string - if ref.RemoteSHA == "" || ref.RemoteSHA == zeroSHA { - args = []string{"ls-tree", "-r", "--name-only", ref.LocalSHA} - } else { - args = []string{"diff", "--name-only", ref.RemoteSHA, ref.LocalSHA} - } - out, err := gitOutput(ctx, args...) - if err != nil { - return nil, err - } - for _, line := range strings.Split(strings.TrimSpace(out), "\n") { - line = strings.TrimSpace(line) - if line == "" { - continue - } - set[line] = struct{}{} - } - } - paths := make([]string, 0, len(set)) - for path := range set { - paths = append(paths, path) - } - sort.Strings(paths) - return paths, nil -} - -func gitOutput(ctx context.Context, args ...string) (string, error) { - cmd := exec.CommandContext(ctx, "git", args...) - cmd.Env = os.Environ() - out, err := cmd.CombinedOutput() - if err != nil { - return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(string(out))) - } - return string(out), nil -} - -// readPushedBranches reads git push lines from the provided temp file, -// extracts unique local branch names for refs under `refs/heads/` and -// returns them sorted. The file is rewound to the start before returning. -func readPushedBranches(f *os.File) ([]string, error) { - // Ensure we read from start - // example: - // refs/heads/main 67890abcdef1234567890abcdef1234567890abcd refs/heads/main 12345abcdef67890abcdef1234567890abcdef12 - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - scanner := bufio.NewScanner(f) - set := make(map[string]struct{}) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) < 1 { - continue - } - localRef := fields[0] - const prefix = "refs/heads/" - if strings.HasPrefix(localRef, prefix) { - branch := strings.TrimPrefix(localRef, prefix) - if branch != "" { - set[branch] = struct{}{} - } - } - } - if err := scanner.Err(); err != nil { - return nil, err - } - branches := make([]string, 0, len(set)) - for b := range set { - branches = append(branches, b) - } - sort.Strings(branches) - // Rewind so caller can reuse the file - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - return branches, nil -} diff --git a/cmd/prepush/main_test.go b/cmd/prepush/main_test.go deleted file mode 100644 index 5fad785b..00000000 --- a/cmd/prepush/main_test.go +++ /dev/null @@ -1,622 +0,0 @@ -package prepush - -import ( - "context" - "encoding/base64" - "encoding/json" - "io" - "log/slog" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drsobject" - "github.com/calypr/git-drs/internal/lfs" - "github.com/calypr/git-drs/internal/precommit_cache" - drsapi "github.com/calypr/syfon/apigen/client/drs" -) - -func TestLfsFilesFromCache(t *testing.T) { - repo := setupGitRepo(t) - filePath := filepath.Join(repo, "data", "file.bin") - if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { - t.Fatalf("mkdir: %v", err) - } - if err := os.WriteFile(filePath, []byte("first"), 0o644); err != nil { - t.Fatalf("write: %v", err) - } - gitCmd(t, repo, "add", "data/file.bin") - gitCmd(t, repo, "commit", "-m", "first") - oldSHA := gitOutputString(t, repo, "rev-parse", "HEAD") - - if err := os.WriteFile(filePath, []byte("second"), 0o644); err != nil { - t.Fatalf("write: %v", err) - } - gitCmd(t, repo, "add", "data/file.bin") - gitCmd(t, repo, "commit", "-m", "second") - newSHA := gitOutputString(t, repo, "rev-parse", "HEAD") - - cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") - cache := &precommit_cache.Cache{ - GitDir: filepath.Join(repo, ".git"), - Root: cacheRoot, - PathsDir: filepath.Join(cacheRoot, "paths"), - OIDsDir: filepath.Join(cacheRoot, "oids"), - StatePath: filepath.Join(cacheRoot, "state.json"), - } - if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil { - t.Fatalf("mkdir paths dir: %v", err) - } - if err := os.MkdirAll(cache.OIDsDir, 0o755); err != nil { - t.Fatalf("mkdir oids dir: %v", err) - } - pathEntry := precommit_cache.PathEntry{ - Path: "data/file.bin", - LFSOID: "oid-123", - UpdatedAt: time.Now().UTC().Format(time.RFC3339), - } - pathEntryFile := filepath.Join(cache.PathsDir, precommit_cache.EncodePath(pathEntry.Path)+".json") - writeJSON(t, pathEntryFile, pathEntry) - - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - - logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - refs := []pushedRef{{ - LocalRef: "refs/heads/main", - LocalSHA: newSHA, - RemoteRef: "refs/heads/main", - RemoteSHA: oldSHA, - }} - - lfsFiles, ok, err := lfsFilesFromCache(context.Background(), cache, refs, logger) - if err != nil { - t.Fatalf("lfsFilesFromCache: %v", err) - } - if !ok { - t.Fatalf("expected cache to be usable") - } - info, exists := lfsFiles["data/file.bin"] - if !exists { - t.Fatalf("expected lfs info for data/file.bin") - } - if info.Oid != "oid-123" { - t.Fatalf("expected oid to be oid-123, got %s", info.Oid) - } - if info.OidType != "sha256" { - t.Fatalf("expected oid type sha256, got %s", info.OidType) - } - stat, err := os.Stat(filePath) - if err != nil { - t.Fatalf("stat: %v", err) - } - if info.Size != stat.Size() { - t.Fatalf("expected size %d, got %d", stat.Size(), info.Size) - } -} - -func TestReadPushedBranches(t *testing.T) { - tests := []struct { - name string - input string - expected []string // Sorted - }{ - { - name: "single branch", - input: "refs/heads/main 1234 oid123 refs/heads/main 1234 oid456", - expected: []string{"main"}, - }, - { - name: "multiple branches", - input: "refs/heads/main 123 oid refs/heads/main 456 oid\nrefs/heads/feature/foo 789 oid remote 000 oid", - expected: []string{"feature/foo", "main"}, - }, - { - name: "ignore tags", - input: "refs/tags/v1.0 123 oid refs/tags/v1.0 123 oid", - expected: []string{}, - }, - { - name: "empty input", - input: "", - expected: []string{}, - }, - { - name: "malformed lines", - input: "just-garbage\nrefs/heads/ok 1 2 3", - expected: []string{"ok"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tmp, err := os.CreateTemp("", "test-stdin") - if err != nil { - t.Fatalf("create temp: %v", err) - } - defer os.Remove(tmp.Name()) - - if _, err := tmp.WriteString(tt.input); err != nil { - t.Fatalf("write temp: %v", err) - } - - // readPushedBranches seeks to 0 itself, but we pass the *os.File - // which must be valid. - branches, err := readPushedBranches(tmp) - if err != nil { - t.Fatalf("readPushedBranches error: %v", err) - } - - if len(branches) != len(tt.expected) { - t.Errorf("expected %d branches, got %d: %v", len(tt.expected), len(branches), branches) - return - } - for i := range branches { - if branches[i] != tt.expected[i] { - t.Errorf("branch mismatch at %d: got %s, want %s", i, branches[i], tt.expected[i]) - } - } - - tmp.Close() - }) - } -} - -func TestLfsFilesFromCacheStale(t *testing.T) { - repo := setupGitRepo(t) - filePath := filepath.Join(repo, "data", "file.bin") - if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { - t.Fatalf("mkdir: %v", err) - } - if err := os.WriteFile(filePath, []byte("data"), 0o644); err != nil { - t.Fatalf("write: %v", err) - } - gitCmd(t, repo, "add", "data/file.bin") - gitCmd(t, repo, "commit", "-m", "first") - sha := gitOutputString(t, repo, "rev-parse", "HEAD") - - cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") - cache := &precommit_cache.Cache{ - GitDir: filepath.Join(repo, ".git"), - Root: cacheRoot, - PathsDir: filepath.Join(cacheRoot, "paths"), - OIDsDir: filepath.Join(cacheRoot, "oids"), - StatePath: filepath.Join(cacheRoot, "state.json"), - } - if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil { - t.Fatalf("mkdir paths dir: %v", err) - } - - pathEntry := precommit_cache.PathEntry{ - Path: "data/file.bin", - LFSOID: "oid-123", - UpdatedAt: time.Now().Add(-48 * time.Hour).UTC().Format(time.RFC3339), - } - pathEntryFile := filepath.Join(cache.PathsDir, precommit_cache.EncodePath(pathEntry.Path)+".json") - writeJSON(t, pathEntryFile, pathEntry) - - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - - logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - refs := []pushedRef{{ - LocalRef: "refs/heads/main", - LocalSHA: sha, - RemoteRef: "refs/heads/main", - RemoteSHA: "0000000000000000000000000000000000000000", - }} - - _, ok, err := lfsFilesFromCache(context.Background(), cache, refs, logger) - if err != nil { - t.Fatalf("lfsFilesFromCache: %v", err) - } - if ok { - t.Fatalf("expected cache to be stale") - } -} - -func TestLfsFilesFromCacheNormalizesOID(t *testing.T) { - repo := setupGitRepo(t) - filePath := filepath.Join(repo, "data", "file.bin") - if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { - t.Fatalf("mkdir: %v", err) - } - if err := os.WriteFile(filePath, []byte("first"), 0o644); err != nil { - t.Fatalf("write: %v", err) - } - gitCmd(t, repo, "add", "data/file.bin") - gitCmd(t, repo, "commit", "-m", "first") - oldSHA := gitOutputString(t, repo, "rev-parse", "HEAD") - - if err := os.WriteFile(filePath, []byte("second"), 0o644); err != nil { - t.Fatalf("write: %v", err) - } - gitCmd(t, repo, "add", "data/file.bin") - gitCmd(t, repo, "commit", "-m", "second") - newSHA := gitOutputString(t, repo, "rev-parse", "HEAD") - - cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") - cache := &precommit_cache.Cache{ - GitDir: filepath.Join(repo, ".git"), - Root: cacheRoot, - PathsDir: filepath.Join(cacheRoot, "paths"), - OIDsDir: filepath.Join(cacheRoot, "oids"), - StatePath: filepath.Join(cacheRoot, "state.json"), - } - if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil { - t.Fatalf("mkdir paths dir: %v", err) - } - - rawOID := strings.Repeat("a", 64) - pathEntry := precommit_cache.PathEntry{ - Path: "data/file.bin", - LFSOID: " sha256:" + rawOID + " ", - UpdatedAt: time.Now().UTC().Format(time.RFC3339), - } - pathEntryFile := filepath.Join(cache.PathsDir, precommit_cache.EncodePath(pathEntry.Path)+".json") - writeJSON(t, pathEntryFile, pathEntry) - - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - - logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - refs := []pushedRef{{ - LocalRef: "refs/heads/main", - LocalSHA: newSHA, - RemoteRef: "refs/heads/main", - RemoteSHA: oldSHA, - }} - - lfsFiles, ok, err := lfsFilesFromCache(context.Background(), cache, refs, logger) - if err != nil { - t.Fatalf("lfsFilesFromCache: %v", err) - } - if !ok { - t.Fatalf("expected cache to be usable") - } - info, exists := lfsFiles["data/file.bin"] - if !exists { - t.Fatalf("expected lfs info for data/file.bin") - } - if info.Oid != rawOID { - t.Fatalf("expected normalized oid %q, got %q", rawOID, info.Oid) - } -} - -func TestBufferStdinCleansUpTempFileOnCopyError(t *testing.T) { - tmpDir := t.TempDir() - tmpPath := "" - _, err := bufferStdin(errReader{}, func(dir, pattern string) (*os.File, error) { - f, createErr := os.CreateTemp(tmpDir, pattern) - if createErr != nil { - return nil, createErr - } - tmpPath = f.Name() - return f, nil - }) - if err == nil { - t.Fatalf("expected bufferStdin error") - } - if _, statErr := os.Stat(tmpPath); !os.IsNotExist(statErr) { - t.Fatalf("expected temp file to be removed, stat err=%v", statErr) - } -} - -func TestSubmitPendingLFSMetaRequestWiring(t *testing.T) { - repo := setupGitRepo(t) - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - gitCmd(t, repo, "config", "drs.remote.origin.token", "test-token") - - oid := strings.Repeat("b", 64) - name := "obj-name" - if err := drsobject.WriteObject(".git/drs/lfs/objects", &drsapi.DrsObject{ - Id: "drs://local:obj-id", - Name: ptrString(name), - Size: 123, - Checksums: []drsapi.Checksum{ - {Type: "sha256", Checksum: oid}, - }, - }, oid); err != nil { - t.Fatalf("write drs object: %v", err) - } - - var gotPath, gotAuth, gotContentType, gotAccept string - var gotBody metadataSubmitRequest - restoreClient := stubPendingMetadataClient(t, func(req *http.Request) (*http.Response, error) { - gotPath = req.URL.Path - gotAuth = req.Header.Get("Authorization") - gotContentType = req.Header.Get("Content-Type") - gotAccept = req.Header.Get("Accept") - defer req.Body.Close() - if err := json.NewDecoder(req.Body).Decode(&gotBody); err != nil { - t.Fatalf("decode request body: %v", err) - } - return &http.Response{ - StatusCode: http.StatusOK, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader("{}")), - Request: req, - }, nil - }) - t.Cleanup(restoreClient) - - logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - err := submitPendingLFSMeta( - context.Background(), - config.Remote("origin"), - "https://example.test/ ", - map[string]lfs.LfsFileInfo{"file.bin": {Oid: oid}}, - logger, - ) - if err != nil { - t.Fatalf("submitPendingLFSMeta: %v", err) - } - if gotPath != "/info/lfs/objects/metadata" { - t.Fatalf("expected metadata endpoint path, got %q", gotPath) - } - if gotAuth != "Bearer test-token" { - t.Fatalf("expected auth header, got %q", gotAuth) - } - if gotContentType != "application/json" { - t.Fatalf("expected content-type application/json, got %q", gotContentType) - } - if gotAccept != "application/vnd.git-lfs+json" { - t.Fatalf("expected accept header application/vnd.git-lfs+json, got %q", gotAccept) - } - if len(gotBody.Candidates) != 1 { - t.Fatalf("expected 1 candidate, got %d", len(gotBody.Candidates)) - } - if len(gotBody.Candidates[0].Checksums) == 0 { - t.Fatalf("expected candidate checksums to be populated") - } -} - -func TestSubmitPendingLFSMetaStatusHandling(t *testing.T) { - repo := setupGitRepo(t) - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - - oid := strings.Repeat("c", 64) - name := "obj-name" - if err := drsobject.WriteObject(".git/drs/lfs/objects", &drsapi.DrsObject{ - Id: "drs://local:obj-id", - Name: ptrString(name), - Size: 123, - Checksums: []drsapi.Checksum{ - {Type: "sha256", Checksum: oid}, - }, - }, oid); err != nil { - t.Fatalf("write drs object: %v", err) - } - - tests := []struct { - name string - status int - contentType string - body string - wantErr bool - }{ - {name: "ok", status: http.StatusOK, contentType: "application/json", body: "{}", wantErr: false}, - {name: "degrade 404", status: http.StatusNotFound, contentType: "application/json", body: "{}", wantErr: false}, - {name: "degrade 405", status: http.StatusMethodNotAllowed, contentType: "application/json", body: "{}", wantErr: false}, - {name: "degrade 501", status: http.StatusNotImplemented, contentType: "application/json", body: "{}", wantErr: false}, - {name: "degrade html", status: http.StatusInternalServerError, contentType: "text/html; charset=utf-8", body: "error", wantErr: false}, - {name: "hard fail 401", status: http.StatusUnauthorized, contentType: "application/json", body: "{\"error\":\"unauthorized\"}", wantErr: true}, - {name: "hard fail 500", status: http.StatusInternalServerError, contentType: "application/json", body: "{\"error\":\"server\"}", wantErr: true}, - } - - logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - restoreClient := stubPendingMetadataClient(t, func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: tc.status, - Header: http.Header{"Content-Type": []string{tc.contentType}}, - Body: io.NopCloser(strings.NewReader(tc.body)), - Request: req, - }, nil - }) - t.Cleanup(restoreClient) - - err := submitPendingLFSMeta( - context.Background(), - config.Remote("origin"), - "https://example.test", - map[string]lfs.LfsFileInfo{"file.bin": {Oid: oid}}, - logger, - ) - if tc.wantErr && err == nil { - t.Fatalf("expected error, got nil") - } - if !tc.wantErr && err != nil { - t.Fatalf("expected nil error, got %v", err) - } - }) - } -} - -func TestResolveRemoteAuthHeaderBasicAuth(t *testing.T) { - repo := setupGitRepo(t) - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - gitCmd(t, repo, "config", "drs.remote.origin.username", "alice") - gitCmd(t, repo, "config", "drs.remote.origin.password", "secret") - - header, ok := resolveRemoteAuthHeader("origin") - if !ok { - t.Fatalf("expected auth header") - } - want := "Basic " + base64.StdEncoding.EncodeToString([]byte("alice:secret")) - if header != want { - t.Fatalf("expected %q, got %q", want, header) - } -} - -func TestResolveRemoteAuthHeaderPrefersBearerTokenOverBasicAuth(t *testing.T) { - repo := setupGitRepo(t) - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - gitCmd(t, repo, "config", "drs.remote.origin.token", "test-token") - gitCmd(t, repo, "config", "drs.remote.origin.username", "alice") - gitCmd(t, repo, "config", "drs.remote.origin.password", "secret") - - header, ok := resolveRemoteAuthHeader("origin") - if !ok { - t.Fatalf("expected auth header") - } - if header != "Bearer test-token" { - t.Fatalf("expected bearer token to win, got %q", header) - } -} - -func TestResolveRemoteAuthHeaderBasicAuthRequiresBothFields(t *testing.T) { - repo := setupGitRepo(t) - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - gitCmd(t, repo, "config", "drs.remote.origin.username", "alice") - - header, ok := resolveRemoteAuthHeader("origin") - if ok { - t.Fatalf("expected no auth header, got %q", header) - } - if header != "" { - t.Fatalf("expected empty header, got %q", header) - } -} - -func TestSubmitPendingLFSMetaRequestWiringBasicAuth(t *testing.T) { - repo := setupGitRepo(t) - oldwd := mustChdir(t, repo) - t.Cleanup(func() { _ = os.Chdir(oldwd) }) - gitCmd(t, repo, "config", "drs.remote.origin.username", "alice") - gitCmd(t, repo, "config", "drs.remote.origin.password", "secret") - - oid := strings.Repeat("d", 64) - name := "obj-name" - if err := drsobject.WriteObject(".git/drs/lfs/objects", &drsapi.DrsObject{ - Id: "drs://local:obj-id", - Name: ptrString(name), - Size: 123, - Checksums: []drsapi.Checksum{ - {Type: "sha256", Checksum: oid}, - }, - }, oid); err != nil { - t.Fatalf("write drs object: %v", err) - } - - var gotAuth string - restoreClient := stubPendingMetadataClient(t, func(req *http.Request) (*http.Response, error) { - gotAuth = req.Header.Get("Authorization") - return &http.Response{ - StatusCode: http.StatusOK, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader("{}")), - Request: req, - }, nil - }) - t.Cleanup(restoreClient) - - logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - err := submitPendingLFSMeta( - context.Background(), - config.Remote("origin"), - "https://example.test", - map[string]lfs.LfsFileInfo{"file.bin": {Oid: oid}}, - logger, - ) - if err != nil { - t.Fatalf("submitPendingLFSMeta: %v", err) - } - - want := "Basic " + base64.StdEncoding.EncodeToString([]byte("alice:secret")) - if gotAuth != want { - t.Fatalf("expected basic auth header %q, got %q", want, gotAuth) - } -} - -type errReader struct{} - -func (errReader) Read([]byte) (int, error) { - return 0, io.ErrUnexpectedEOF -} - -func setupGitRepo(t *testing.T) string { - t.Helper() - dir := t.TempDir() - gitCmd(t, dir, "init") - gitCmd(t, dir, "config", "user.email", "test@example.com") - gitCmd(t, dir, "config", "user.name", "Test User") - return dir -} - -func gitCmd(t *testing.T, dir string, args ...string) { - t.Helper() - cmd := exec.Command("git", args...) - cmd.Dir = dir - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) - } -} - -func gitOutputString(t *testing.T, dir string, args ...string) string { - t.Helper() - cmd := exec.Command("git", args...) - cmd.Dir = dir - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) - } - return strings.TrimSpace(string(out)) -} - -func writeJSON(t *testing.T, path string, value any) { - t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - t.Fatalf("mkdir: %v", err) - } - data, err := json.Marshal(value) - if err != nil { - t.Fatalf("marshal: %v", err) - } - if err := os.WriteFile(path, data, 0o644); err != nil { - t.Fatalf("write: %v", err) - } -} - -func mustChdir(t *testing.T, dir string) string { - t.Helper() - old, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd: %v", err) - } - if err := os.Chdir(dir); err != nil { - t.Fatalf("Chdir(%s): %v", dir, err) - } - return old -} - -func stubPendingMetadataClient(t *testing.T, respond func(*http.Request) (*http.Response, error)) func() { - t.Helper() - orig := pendingMetadataClientFactory - pendingMetadataClientFactory = func() *http.Client { - return &http.Client{Transport: roundTripperFunc(respond)} - } - return func() { - pendingMetadataClientFactory = orig - } -} - -type roundTripperFunc func(*http.Request) (*http.Response, error) - -func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) -} - -func ptrString(s string) *string { return &s } diff --git a/cmd/pull/main.go b/cmd/pull/main.go index cf352d73..76553da1 100644 --- a/cmd/pull/main.go +++ b/cmd/pull/main.go @@ -2,31 +2,46 @@ package pull import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" + "io" + "log/slog" "net/url" "os" - "os/exec" + "path/filepath" + "regexp" + "sort" "strings" - "github.com/bytedance/sonic" - "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" + "github.com/calypr/git-drs/internal/gitrepo" "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/lookup" + "github.com/calypr/git-drs/internal/remoteruntime" + internaltransfer "github.com/calypr/git-drs/internal/transfer" drsapi "github.com/calypr/syfon/apigen/client/drs" + sycommon "github.com/calypr/syfon/client/common" "github.com/spf13/cobra" ) -var runCommand = func(name string, args ...string) ([]byte, error) { - cmd := exec.Command(name, args...) - return cmd.CombinedOutput() -} +var includePatterns []string +var dryRun bool + +var ( + loadCfg = config.LoadConfig + resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { return cfg.GetRemoteOrDefault(name) } + newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*remoteruntime.GitContext, error) { + return remoteruntime.New(cfg, remote, logger) + } + loadWorktreeInventory = lfs.GetTrackedLfsFiles +) var Cmd = &cobra.Command{ Use: "pull [remote-name]", - Short: "Pull using the standard Git + Git LFS flow", - Long: "Pull using the standard Git + Git LFS flow (git pull, git lfs pull, git lfs checkout).", + Short: "Download DRS pointer file content into the current checkout", + Long: "Hydrate DRS/Git-LFS pointer files in the current checkout. By default this mirrors git lfs pull semantics for the worktree rather than running git pull.", Args: func(cmd *cobra.Command, args []string) error { if len(args) > 1 { cmd.SilenceUsage = false @@ -34,10 +49,10 @@ var Cmd = &cobra.Command{ } return nil }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) (retErr error) { logg := drslog.GetLogger() - cfg, err := config.LoadConfig() + cfg, err := loadCfg() if err != nil { return fmt.Errorf("error loading config: %v", err) } @@ -46,55 +61,59 @@ var Cmd = &cobra.Command{ if len(args) > 0 { remote = config.Remote(args[0]) } else { - remote, err = cfg.GetDefaultRemote() + remote, err = resolveRemote(cfg, "") if err != nil { logg.Error(fmt.Sprintf("Error getting remote: %v", err)) return err } } - drsCtx, err := cfg.GetRemoteClient(remote, logg) + drsCtx, err := newRemoteClient(cfg, remote, logg) if err != nil { logg.Error(fmt.Sprintf("error creating DRS client: %s", err)) return err } - _ = drsCtx // Remote validation only. - if out, err := runCommand("git", "pull", string(remote)); err != nil { - msg := strings.TrimSpace(string(out)) - if msg == "" { - msg = err.Error() - } - return fmt.Errorf("git pull failed for remote %q: %s", remote, msg) + inventory, err := loadWorktreeInventory(logg) + if err != nil { + return fmt.Errorf("failed to discover pointer files in worktree: %w", err) } - - var parsed struct { - Files []lfs.LfsFileInfo `json:"files"` + pointers := collectPointerFiles(inventory, includePatterns) + if len(pointers) == 0 { + logg.Debug("no matching pointer files to hydrate") + return nil } - out, err := runCommand("git", "lfs", "ls-files", "--json") - if err != nil { - msg := commandMessage(out, err) - if !isMissingGitLFS(msg) { - return fmt.Errorf("git lfs ls-files failed: %s", msg) - } - lfsFiles, inventoryErr := lfs.GetAllLfsFiles(string(remote), "", []string{"HEAD"}, logg) - if inventoryErr != nil { - return fmt.Errorf("git lfs ls-files failed: %s; fallback inventory failed: %w", msg, inventoryErr) + + progress := internaltransfer.NewPullProgressRenderer(os.Stderr) + progress.OnPlan(toPullFiles(pointers)) + defer func() { + if finishErr := progress.Finish(); retErr == nil && finishErr != nil { + retErr = fmt.Errorf("finalize pull progress: %w", finishErr) } - parsed.Files = make([]lfs.LfsFileInfo, 0, len(lfsFiles)) - for _, f := range lfsFiles { - parsed.Files = append(parsed.Files, f) + }() + + if dryRun { + for _, f := range pointers { + if _, err := fmt.Fprintln(cmd.OutOrStdout(), f.Name); err != nil { + return err + } } - } else if err := lfsjsonUnmarshal(out, &parsed); err != nil { - return fmt.Errorf("failed to parse git lfs ls-files output: %w", err) + return nil } ctx := context.Background() - missingOIDs := make([]string, 0, len(parsed.Files)) - seenMissing := make(map[string]struct{}, len(parsed.Files)) - for _, f := range parsed.Files { - if f.Downloaded { + missingOIDs := make([]string, 0, len(pointers)) + seenMissing := make(map[string]struct{}, len(pointers)) + for _, f := range pointers { + cachePath, err := lfs.ObjectPath(gitrepo.LFSObjectsPath, f.Oid) + if err != nil { + return fmt.Errorf("failed to resolve LFS object path for %s: %w", f.Oid, err) + } + state, err := inspectCachedObject(cachePath, f.Oid, f.Size) + if err == nil && state.complete { continue + } else if err != nil { + return fmt.Errorf("failed to stat cached object for %s: %w", f.Oid, err) } if _, seen := seenMissing[f.Oid]; seen { continue @@ -106,7 +125,7 @@ var Cmd = &cobra.Command{ if len(missingOIDs) > 0 { prefetched := make(map[string]drsapi.DrsObject, len(missingOIDs)) for _, oid := range missingOIDs { - recs, err := drsremote.ObjectsByHashForScope(ctx, drsCtx, oid) + recs, err := lookup.ObjectsByHashForScope(ctx, drsCtx, oid) if err != nil || len(recs) == 0 { continue } @@ -124,47 +143,51 @@ var Cmd = &cobra.Command{ for _, obj := range prefetched { objects = append(objects, obj) } - if resolved, err := drsremote.BulkAccessURLsForObjects(ctx, drsCtx, objects); err == nil { + if resolved, err := internaltransfer.BulkAccessURLsForObjects(ctx, drsCtx, objects); err == nil { prefetchedAccess = resolved logg.Debug(fmt.Sprintf("bulk access resolved %d URLs for pull", len(prefetchedAccess))) } else { logg.Debug(fmt.Sprintf("bulk access prefetch failed; continuing per-object: %v", err)) } } - for _, f := range parsed.Files { - if f.Downloaded { - continue - } - dstPath, err := lfs.ObjectPath(common.LFS_OBJS_PATH, f.Oid) + for _, f := range pointers { + dstPath, err := lfs.ObjectPath(gitrepo.LFSObjectsPath, f.Oid) if err != nil { return fmt.Errorf("failed to resolve LFS object path for %s: %w", f.Oid, err) } + state, err := inspectCachedObject(dstPath, f.Oid, f.Size) + if err == nil && state.complete { + continue + } else if err != nil { + return fmt.Errorf("failed to stat cache path %s: %w", dstPath, err) + } + if state.exists { + if err := os.Remove(dstPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove incomplete cached object %s: %w", dstPath, err) + } + } + progress.OnDownloadStart(toPullFile(f)) + downloadCtx := progressContextForPointer(ctx, progress, f) if obj, ok := prefetched[f.Oid]; ok { if accessURL, ok := prefetchedAccess[obj.Id]; ok { objCopy := obj - if err := drsremote.DownloadResolvedToCachePath(ctx, drsCtx, f.Oid, dstPath, &objCopy, &accessURL); err != nil { + if err := internaltransfer.DownloadResolvedToCachePath(downloadCtx, drsCtx, f.Oid, dstPath, &objCopy, &accessURL); err != nil { debugCtx := buildPullDownloadDebugContext(ctx, drsCtx, f.Oid) return fmt.Errorf("failed to download oid %s to %s: %w\npull-debug: %s", f.Oid, dstPath, err, debugCtx) } continue } } - if err := drsremote.DownloadToCachePath(ctx, drsCtx, logg, f.Oid, dstPath); err != nil { + if err := internaltransfer.DownloadToCachePath(downloadCtx, drsCtx, f.Oid, dstPath); err != nil { debugCtx := buildPullDownloadDebugContext(ctx, drsCtx, f.Oid) return fmt.Errorf("failed to download oid %s to %s: %w\npull-debug: %s", f.Oid, dstPath, err, debugCtx) } } } else { - logg.Debug("no missing LFS objects to download") + logg.Debug("no missing pointer objects to download") } - if out, err := runCommand("git", "lfs", "checkout"); err != nil { - msg := commandMessage(out, err) - if !isMissingGitLFS(msg) { - return fmt.Errorf("git lfs checkout failed: %s", msg) - } - } - if err := checkoutDownloadedFiles(parsed.Files); err != nil { + if err := checkoutDownloadedFiles(pointers, progress); err != nil { return err } @@ -172,44 +195,203 @@ var Cmd = &cobra.Command{ }, } -func commandMessage(out []byte, err error) string { - msg := strings.TrimSpace(string(out)) - if msg == "" && err != nil { - msg = err.Error() +type pointerFile struct { + Name string + Oid string + Size int64 +} + +func collectPointerFiles(inventory map[string]lfs.LfsFileInfo, patterns []string) []pointerFile { + keys := make([]string, 0, len(inventory)) + for path := range inventory { + if !matchesAnyPattern(path, patterns) { + continue + } + keys = append(keys, path) + } + sort.Strings(keys) + + files := make([]pointerFile, 0, len(keys)) + for _, path := range keys { + info := inventory[path] + files = append(files, pointerFile{Name: path, Oid: info.Oid, Size: info.Size}) + } + return files +} + +func progressContextForPointer(ctx context.Context, progress *internaltransfer.PullProgressRenderer, file pointerFile) context.Context { + ctx = sycommon.WithOid(ctx, file.Name) + return sycommon.WithProgress(ctx, func(ev sycommon.ProgressEvent) error { + if ev.Event != "progress" { + return nil + } + progress.OnDownloadProgress(file.Name, ev.BytesSoFar, file.Size) + return nil + }) +} + +func toPullFiles(files []pointerFile) []internaltransfer.PullFile { + out := make([]internaltransfer.PullFile, 0, len(files)) + for _, file := range files { + out = append(out, toPullFile(file)) + } + return out +} + +func toPullFile(file pointerFile) internaltransfer.PullFile { + return internaltransfer.PullFile{Name: file.Name, Oid: file.Oid, Size: file.Size} +} + +func matchesAnyPattern(path string, patterns []string) bool { + if len(patterns) == 0 { + return true + } + normalized := filepath.ToSlash(filepath.Clean(path)) + for _, pattern := range patterns { + pattern = strings.TrimSpace(pattern) + if pattern == "" { + continue + } + if matchesPattern(normalized, pattern) { + return true + } + } + return false +} + +type cachedObjectState struct { + exists bool + complete bool +} + +func inspectCachedObject(path, expectedOID string, expectedSize int64) (cachedObjectState, error) { + var state cachedObjectState + info, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return state, nil + } + return state, err + } + state.exists = true + if info.IsDir() { + return state, fmt.Errorf("cached object path is a directory: %s", path) + } + if expectedSize > 0 && info.Size() != expectedSize { + return state, nil + } + if expectedSize <= 0 && info.Size() <= 0 { + return state, nil + } + if strings.TrimSpace(expectedOID) == "" { + state.complete = true + return state, nil + } + + actualOID, err := calculateFileSHA256(path) + if err != nil { + return state, err + } + state.complete = strings.EqualFold(strings.TrimPrefix(expectedOID, "sha256:"), actualOID) + return state, nil +} + +func calculateFileSHA256(path string) (string, error) { + file, err := os.Open(path) + if err != nil { + return "", err + } + defer file.Close() + + hasher := sha256.New() + if _, err := io.Copy(hasher, file); err != nil { + return "", err + } + return hex.EncodeToString(hasher.Sum(nil)), nil +} + +func matchesPattern(path, pattern string) bool { + pattern = filepath.ToSlash(filepath.Clean(pattern)) + if !strings.ContainsAny(pattern, "*?[") { + return path == pattern + } + re, err := regexp.Compile(globToRegexp(pattern)) + if err != nil { + return false } - return msg + return re.MatchString(path) } -func isMissingGitLFS(msg string) bool { - return strings.Contains(msg, "git: 'lfs' is not a git command") +func globToRegexp(pattern string) string { + var b strings.Builder + b.WriteString("^") + for i := 0; i < len(pattern); i++ { + ch := pattern[i] + switch ch { + case '*': + if i+1 < len(pattern) && pattern[i+1] == '*' { + b.WriteString(".*") + i++ + continue + } + b.WriteString(`[^/]*`) + case '?': + b.WriteString(`[^/]`) + case '.', '+', '(', ')', '|', '^', '$', '{', '}', '[', ']', '\\': + b.WriteByte('\\') + b.WriteByte(ch) + default: + b.WriteByte(ch) + } + } + b.WriteString("$") + return b.String() } -func checkoutDownloadedFiles(files []lfs.LfsFileInfo) error { +func checkoutDownloadedFiles(files []pointerFile, progress *internaltransfer.PullProgressRenderer) error { for _, f := range files { if strings.TrimSpace(f.Name) == "" || strings.TrimSpace(f.Oid) == "" { continue } - srcPath, err := lfs.ObjectPath(common.LFS_OBJS_PATH, f.Oid) + srcPath, err := lfs.ObjectPath(gitrepo.LFSObjectsPath, f.Oid) if err != nil { return fmt.Errorf("failed to resolve cached object for %s: %w", f.Oid, err) } - payload, err := os.ReadFile(srcPath) + src, err := os.Open(srcPath) if err != nil { return fmt.Errorf("failed to read cached object %s: %w", srcPath, err) } - if err := os.WriteFile(f.Name, payload, 0o644); err != nil { + progress.OnCheckoutStart(toPullFile(f)) + if dir := filepath.Dir(f.Name); dir != "." { + if err := os.MkdirAll(dir, 0o755); err != nil { + src.Close() + return fmt.Errorf("failed to create directory for %s: %w", f.Name, err) + } + } + dst, err := os.OpenFile(f.Name, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + if err != nil { + src.Close() return fmt.Errorf("failed to checkout %s: %w", f.Name, err) } + if _, err := io.Copy(dst, src); err != nil { + dst.Close() + src.Close() + return fmt.Errorf("failed to checkout %s: %w", f.Name, err) + } + if err := dst.Close(); err != nil { + src.Close() + return fmt.Errorf("failed to finalize checkout for %s: %w", f.Name, err) + } + if err := src.Close(); err != nil { + return fmt.Errorf("failed to close cached object %s: %w", srcPath, err) + } + progress.OnCompleted(toPullFile(f)) } return nil } -var lfsjsonUnmarshal = func(data []byte, v any) error { - return sonic.ConfigFastest.Unmarshal(data, v) -} - -func buildPullDownloadDebugContext(ctx context.Context, drsCtx *config.GitContext, oid string) string { - recs, err := drsremote.ObjectsByHashForScope(ctx, drsCtx, oid) +func buildPullDownloadDebugContext(ctx context.Context, drsCtx *remoteruntime.GitContext, oid string) string { + recs, err := lookup.ObjectsByHashForScope(ctx, drsCtx, oid) if err != nil { return fmt.Sprintf("oid=%s query_error=%v", oid, err) } @@ -242,3 +424,8 @@ func buildPullDownloadDebugContext(ctx context.Context, drsCtx *config.GitContex } return fmt.Sprintf("oid=%s did=%s size=%d access_methods=%s", oid, strings.TrimSpace(match.Id), match.Size, strings.Join(methods, ", ")) } + +func init() { + Cmd.Flags().StringArrayVarP(&includePatterns, "include", "I", nil, "include pathspec/glob pattern(s)") + Cmd.Flags().BoolVar(&dryRun, "dry-run", false, "list matching pointer files without downloading them") +} diff --git a/cmd/pull/pull_test.go b/cmd/pull/pull_test.go index 41c999ac..b44f2298 100644 --- a/cmd/pull/pull_test.go +++ b/cmd/pull/pull_test.go @@ -1,34 +1,213 @@ package pull import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "log/slog" + "os" + "os/exec" + "path/filepath" "testing" "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/testutils" - "github.com/stretchr/testify/assert" + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/remoteruntime" ) -func TestPullCmdArgs(t *testing.T) { - err := Cmd.Args(Cmd, []string{}) - assert.NoError(t, err) +func resetPullFlagsForTest() { + includePatterns = nil + dryRun = false +} + +func TestCollectPointerFilesFiltersAndSorts(t *testing.T) { + resetPullFlagsForTest() - err = Cmd.Args(Cmd, []string{"origin"}) - assert.NoError(t, err) + inventory := map[string]lfs.LfsFileInfo{ + "data/b.bin": {Name: "data/b.bin", Oid: "bbbb", Size: 2}, + "data/a.bin": {Name: "data/a.bin", Oid: "aaaa", Size: 1}, + "misc/c.bin": {Name: "misc/c.bin", Oid: "cccc", Size: 3}, + } - err = Cmd.Args(Cmd, []string{"origin", "extra"}) - assert.Error(t, err) + files := collectPointerFiles(inventory, []string{"data/**"}) + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + if files[0].Name != "data/a.bin" || files[1].Name != "data/b.bin" { + t.Fatalf("unexpected file order: %+v", files) + } } -func TestPullRun_LoadConfigError(t *testing.T) { - _ = testutils.SetupTestGitRepo(t) - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) +func TestPullDryRunListsMatchingPaths(t *testing.T) { + resetPullFlagsForTest() + + oldLoadCfg := loadCfg + oldResolveRemote := resolveRemote + oldNewRemoteClient := newRemoteClient + oldInventory := loadWorktreeInventory + t.Cleanup(func() { + loadCfg = oldLoadCfg + resolveRemote = oldResolveRemote + newRemoteClient = oldNewRemoteClient + loadWorktreeInventory = oldInventory + }) + + loadCfg = func() (*config.Config, error) { return &config.Config{}, nil } + resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { return config.Remote("origin"), nil } + newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*remoteruntime.GitContext, error) { + return &remoteruntime.GitContext{}, nil + } + loadWorktreeInventory = func(_ *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + return map[string]lfs.LfsFileInfo{ + "data/a.bin": {Name: "data/a.bin", Oid: "aaaa", Size: 1}, + "misc/b.bin": {Name: "misc/b.bin", Oid: "bbbb", Size: 2}, + }, nil + } + + includePatterns = []string{"data/**"} + dryRun = true + + var out bytes.Buffer + Cmd.SetOut(&out) + Cmd.SetErr(&out) + Cmd.SetArgs([]string{"--dry-run"}) + t.Cleanup(func() { + Cmd.SetOut(nil) + Cmd.SetErr(nil) + Cmd.SetArgs(nil) + resetPullFlagsForTest() + }) + + if err := Cmd.RunE(Cmd, []string{}); err != nil { + t.Fatalf("RunE returned error: %v", err) + } + if got := out.String(); got != "data/a.bin\n" { + t.Fatalf("unexpected dry-run output: %q", got) + } } -func TestPullRun_DefaultRemoteError(t *testing.T) { - tmpDir := testutils.SetupTestGitRepo(t) - testutils.CreateTestConfig(t, tmpDir, &config.Config{}) +func TestPullUsesTrackedInventoryForHydratedFiles(t *testing.T) { + repo := t.TempDir() + runGitCmdTest(t, repo, "init") + runGitCmdTest(t, repo, "config", "user.email", "test@example.com") + runGitCmdTest(t, repo, "config", "user.name", "Test User") + runGitCmdTest(t, repo, "config", "filter.drs.clean", "cat") + runGitCmdTest(t, repo, "config", "filter.drs.smudge", "cat") + runGitCmdTest(t, repo, "config", "filter.drs.process", "cat") + runGitCmdTest(t, repo, "config", "filter.drs.required", "false") + + attrPath := filepath.Join(repo, ".gitattributes") + if err := os.WriteFile(attrPath, []byte("*.dat filter=drs diff=drs merge=drs -text\n"), 0o644); err != nil { + t.Fatalf("write .gitattributes: %v", err) + } + + oid := "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd" + pointerPath := filepath.Join(repo, "data", "hydrated.dat") + writePointerFile(t, pointerPath, oid, "321") + + runGitCmdTest(t, repo, "add", ".") + runGitCmdTest(t, repo, "commit", "-m", "commit tracked pointer") + + if err := os.WriteFile(pointerPath, []byte("localized payload"), 0o644); err != nil { + t.Fatalf("hydrate tracked file: %v", err) + } + + oldWD, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + if err := os.Chdir(repo); err != nil { + t.Fatalf("chdir repo: %v", err) + } + t.Cleanup(func() { + _ = os.Chdir(oldWD) + }) + + files, err := loadWorktreeInventory(drslog.NewNoOpLogger()) + if err != nil { + t.Fatalf("loadWorktreeInventory error: %v", err) + } + + info, ok := files["data/hydrated.dat"] + if !ok { + t.Fatal("expected hydrated tracked file in pull inventory") + } + if info.Oid != oid || info.Size != 321 { + t.Fatalf("unexpected hydrated tracked info: %+v", info) + } +} + +func TestInspectCachedObject(t *testing.T) { + tmpDir := t.TempDir() + objectPath := filepath.Join(tmpDir, "obj") + + state, err := inspectCachedObject(objectPath, "abc", 10) + if err != nil { + t.Fatalf("missing file returned error: %v", err) + } + if state.exists || state.complete { + t.Fatal("missing file should not be complete") + } + + if err := os.WriteFile(objectPath, []byte("12345"), 0o644); err != nil { + t.Fatalf("write partial object: %v", err) + } + state, err = inspectCachedObject(objectPath, "abc", 10) + if err != nil { + t.Fatalf("partial file returned error: %v", err) + } + if !state.exists { + t.Fatal("partial file should exist") + } + if state.complete { + t.Fatal("partial file should not be complete") + } + + fullContent := []byte("1234567890") + sum := sha256.Sum256(fullContent) + oid := hex.EncodeToString(sum[:]) + if err := os.WriteFile(objectPath, fullContent, 0o644); err != nil { + t.Fatalf("write full object: %v", err) + } + state, err = inspectCachedObject(objectPath, oid, 10) + if err != nil { + t.Fatalf("complete file returned error: %v", err) + } + if !state.complete { + t.Fatal("full file should be complete") + } + + if err := os.WriteFile(objectPath, []byte("abcdefghij"), 0o644); err != nil { + t.Fatalf("write same-size corrupt object: %v", err) + } + state, err = inspectCachedObject(objectPath, oid, 10) + if err != nil { + t.Fatalf("same-size corrupt file returned error: %v", err) + } + if state.complete { + t.Fatal("same-size corrupt file should not be complete") + } +} + +func writePointerFile(t *testing.T, path, oid, size string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir pointer dir: %v", err) + } + content := "version https://git-lfs.github.com/spec/v1\n" + + "oid sha256:" + oid + "\n" + + "size " + size + "\n" + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write pointer file: %v", err) + } +} - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) +func runGitCmdTest(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v failed: %v\n%s", args, err, string(out)) + } } diff --git a/cmd/push/main.go b/cmd/push/main.go index 4d0445ca..05654c25 100644 --- a/cmd/push/main.go +++ b/cmd/push/main.go @@ -3,27 +3,34 @@ package push import ( "context" "fmt" + "os" "os/exec" + "sort" "strings" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/lfs" - "github.com/calypr/git-drs/internal/pushsync" + "github.com/calypr/git-drs/internal/remoteruntime" + internaltransfer "github.com/calypr/git-drs/internal/transfer" "github.com/spf13/cobra" ) var pushWithHooks bool +var pushForceUpload bool var runCommand = func(name string, args ...string) ([]byte, error) { cmd := exec.Command(name, args...) return cmd.CombinedOutput() } +var gitOutputFn = gitOutput +var getRemoteMergeBaseFn = getRemoteMergeBase + var Cmd = &cobra.Command{ Use: "push [remote-name]", Short: "Upload/register DRS objects and push Git refs", - Long: "Performs git-drs managed upload/register flow (multipart for large files) and then runs git push (without pre-push hooks by default).", + Long: "Performs git-drs managed upload/register flow (multipart for large files) and then runs git push.", Args: func(cmd *cobra.Command, args []string) error { if len(args) > 1 { cmd.SilenceUsage = false @@ -31,13 +38,18 @@ var Cmd = &cobra.Command{ } return nil }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) (retErr error) { + fmt.Fprintln(os.Stderr, "DEBUG: ENTERING RunE for push") myLogger := drslog.GetLogger() + ctx := context.Background() + fmt.Fprintln(os.Stderr, "DEBUG: Loading config...") cfg, err := config.LoadConfig() if err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Failed to load config:", err) myLogger.Debug(fmt.Sprintf("Error loading config: %v", err)) return err } + fmt.Fprintln(os.Stderr, "DEBUG: Config loaded successfully") var remote config.Remote if len(args) > 0 { @@ -50,38 +62,174 @@ var Cmd = &cobra.Command{ } } - drsClient, err := cfg.GetRemoteClient(remote, myLogger) + fmt.Fprintln(os.Stderr, "DEBUG: Getting remote client for remote:", remote) + drsClient, err := remoteruntime.New(cfg, remote, myLogger) if err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Failed to get remote client:", err) myLogger.Debug(fmt.Sprintf("Error creating DRS client: %s", err)) return err } - lfsFiles, err := lfs.GetAllLfsFiles(string(remote), "", []string{"HEAD"}, myLogger) + fmt.Fprintln(os.Stderr, "DEBUG: Remote client retrieved. Resolving push refs...") + drsClient.ForceUpload = pushForceUpload + pushRefs, err := currentPushRefUpdates(ctx, string(remote)) + if err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Failed to resolve push refs:", err) + return fmt.Errorf("failed to resolve pushed refs: %w", err) + } + fmt.Fprintln(os.Stderr, "DEBUG: Push refs resolved. Resolving pushed paths...") + pushedPaths, err := listRefUpdatePaths(ctx, pushRefs) if err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Failed to resolve pushed paths:", err) + return fmt.Errorf("failed to resolve pushed paths: %w", err) + } + fmt.Fprintln(os.Stderr, "DEBUG: Pushed paths resolved. Discovering LFS files...") + lfsFiles, err := lfs.GetLfsFilesForRefPaths("HEAD", pushedPaths, myLogger) + if err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Failed to discover LFS files:", err) return fmt.Errorf("failed to discover LFS files to push: %w", err) } + fmt.Fprintln(os.Stderr, "DEBUG: LFS files to push resolved. Total files:", len(lfsFiles)) - ctx := context.Background() - if err := pushsync.BatchSyncForPush(drsClient, ctx, lfsFiles); err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Reconciling committed deletes...") + if _, err := internaltransfer.ReconcileCommittedDeletes(ctx, drsClient, pushRefs, myLogger); err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Failed to reconcile deletes:", err) + return fmt.Errorf("failed to reconcile deletes: %w", err) + } + fmt.Fprintln(os.Stderr, "DEBUG: Deletes reconciled. Starting BatchSyncForPush...") + progress := internaltransfer.NewUploadProgressRenderer(os.Stderr) + if err := internaltransfer.BatchSyncForPush(drsClient, ctx, lfsFiles, progress); err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: BatchSyncForPush failed:", err) + if finishErr := progress.Finish(); finishErr != nil { + return fmt.Errorf("failed batch register/upload workflow: %w (progress finalize error: %v)", err, finishErr) + } return fmt.Errorf("failed batch register/upload workflow: %w", err) } + if err := progress.Finish(); err != nil { + return fmt.Errorf("finalize upload progress: %w", err) + } + fmt.Fprintln(os.Stderr, "DEBUG: BatchSyncForPush completed successfully") + switch { + case len(lfsFiles) == 0: + fmt.Fprintln(os.Stdout, "No git-drs tracked files found; pushing Git refs only.") + case !progress.HadUploads(): + fmt.Fprintln(os.Stdout, "No DRS payload uploads needed; all tracked objects are already available remotely.") + } pushArgs := []string{"push"} if !pushWithHooks { pushArgs = append(pushArgs, "--no-verify") } pushArgs = append(pushArgs, string(remote)) + fmt.Fprintln(os.Stderr, "DEBUG: Invoking git push with args:", pushArgs) out, err := runCommand("git", pushArgs...) if err != nil { + fmt.Fprintln(os.Stderr, "DEBUG: Git push failed:", err) msg := strings.TrimSpace(string(out)) if msg == "" { msg = err.Error() } return fmt.Errorf("git push failed for remote %q: %s", remote, msg) } + fmt.Fprintln(os.Stderr, "DEBUG: Git push completed successfully") return nil }, } func init() { - Cmd.Flags().BoolVar(&pushWithHooks, "with-hooks", false, "Run git push with local hooks enabled (invokes pre-push)") + Cmd.Flags().BoolVar(&pushWithHooks, "with-hooks", false, "Run git push with local hooks enabled") + Cmd.Flags().BoolVar(&pushForceUpload, "force-upload", false, "Upload payload bytes even when a matching downloadable object already exists remotely") +} + +func currentPushRefUpdates(ctx context.Context, remote string) ([]internaltransfer.RefUpdate, error) { + const zeroSHA = "0000000000000000000000000000000000000000" + head, err := gitOutputFn(ctx, "rev-parse", "HEAD") + if err != nil { + return nil, err + } + var oldSHA string + upstream, err := gitOutputFn(ctx, "rev-parse", "--verify", "@{upstream}") + if err == nil { + oldSHA = upstream + } else { + mb, err := getRemoteMergeBaseFn(ctx, remote, head) + if err == nil && mb != "" { + oldSHA = mb + } else { + oldSHA = zeroSHA + } + } + return []internaltransfer.RefUpdate{{ + OldSHA: oldSHA, + NewSHA: head, + }}, nil +} + +func getRemoteMergeBase(ctx context.Context, remote string, head string) (string, error) { + cmd := exec.CommandContext(ctx, "git", "for-each-ref", "--format=%(refname)", "refs/remotes/"+remote+"/") + out, err := cmd.CombinedOutput() + if err != nil { + return "", err + } + lines := strings.Split(strings.TrimSpace(string(out)), "\n") + var refs []string + for _, line := range lines { + line = strings.TrimSpace(line) + if line != "" && !strings.HasSuffix(line, "/HEAD") { + refs = append(refs, line) + } + } + if len(refs) == 0 { + return "", nil + } + args := append([]string{"merge-base", head}, refs...) + cmdMerge := exec.CommandContext(ctx, "git", args...) + outMerge, err := cmdMerge.CombinedOutput() + if err != nil { + return "", nil + } + return strings.TrimSpace(string(outMerge)), nil +} + +func listRefUpdatePaths(ctx context.Context, refs []internaltransfer.RefUpdate) ([]string, error) { + const zeroSHA = "0000000000000000000000000000000000000000" + set := make(map[string]struct{}) + for _, ref := range refs { + newSHA := strings.TrimSpace(ref.NewSHA) + oldSHA := strings.TrimSpace(ref.OldSHA) + if newSHA == "" || newSHA == zeroSHA { + continue + } + var args []string + if oldSHA == "" || oldSHA == zeroSHA { + args = []string{"ls-tree", "-r", "--name-only", newSHA} + } else { + args = []string{"diff", "--name-only", oldSHA, newSHA} + } + out, err := gitOutputFn(ctx, args...) + if err != nil { + return nil, err + } + for _, line := range strings.Split(strings.TrimSpace(out), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + set[line] = struct{}{} + } + } + paths := make([]string, 0, len(set)) + for path := range set { + paths = append(paths, path) + } + sort.Strings(paths) + return paths, nil +} + +func gitOutput(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) + out, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(string(out))) + } + return strings.TrimSpace(string(out)), nil } diff --git a/cmd/push/main_test.go b/cmd/push/main_test.go new file mode 100644 index 00000000..9504d83b --- /dev/null +++ b/cmd/push/main_test.go @@ -0,0 +1,110 @@ +package push + +import ( + "context" + "fmt" + "testing" + + internaltransfer "github.com/calypr/git-drs/internal/transfer" +) + +func TestCurrentPushRefUpdatesUsesZeroBaseWhenUpstreamMissing(t *testing.T) { + oldFn := gitOutputFn + gitOutputFn = func(ctx context.Context, args ...string) (string, error) { + switch fmt.Sprint(args) { + case "[rev-parse HEAD]": + return "head-sha", nil + case "[rev-parse --verify @{upstream}]": + return "", fmt.Errorf("git rev-parse --verify @{upstream}: fatal: no upstream configured") + default: + t.Fatalf("unexpected git args: %v", args) + return "", nil + } + } + t.Cleanup(func() { gitOutputFn = oldFn }) + + oldMbFn := getRemoteMergeBaseFn + getRemoteMergeBaseFn = func(ctx context.Context, remote string, head string) (string, error) { + return "", nil + } + t.Cleanup(func() { getRemoteMergeBaseFn = oldMbFn }) + + got, err := currentPushRefUpdates(context.Background(), "origin") + if err != nil { + t.Fatalf("currentPushRefUpdates returned error: %v", err) + } + if len(got) != 1 || got[0].OldSHA != "0000000000000000000000000000000000000000" || got[0].NewSHA != "head-sha" { + t.Fatalf("unexpected push refs: %+v", got) + } +} + +func TestCurrentPushRefUpdatesUsesRemoteMergeBaseWhenUpstreamMissing(t *testing.T) { + oldFn := gitOutputFn + gitOutputFn = func(ctx context.Context, args ...string) (string, error) { + switch fmt.Sprint(args) { + case "[rev-parse HEAD]": + return "head-sha", nil + case "[rev-parse --verify @{upstream}]": + return "", fmt.Errorf("git rev-parse --verify @{upstream}: fatal: no upstream configured") + default: + t.Fatalf("unexpected git args: %v", args) + return "", nil + } + } + t.Cleanup(func() { gitOutputFn = oldFn }) + + oldMbFn := getRemoteMergeBaseFn + getRemoteMergeBaseFn = func(ctx context.Context, remote string, head string) (string, error) { + if remote != "origin" || head != "head-sha" { + t.Fatalf("unexpected getRemoteMergeBase args: remote=%s, head=%s", remote, head) + } + return "merge-base-sha", nil + } + t.Cleanup(func() { getRemoteMergeBaseFn = oldMbFn }) + + got, err := currentPushRefUpdates(context.Background(), "origin") + if err != nil { + t.Fatalf("currentPushRefUpdates returned error: %v", err) + } + if len(got) != 1 || got[0].OldSHA != "merge-base-sha" || got[0].NewSHA != "head-sha" { + t.Fatalf("unexpected push refs: %+v", got) + } +} + +func TestListRefUpdatePathsUsesDiffForExistingBranch(t *testing.T) { + oldFn := gitOutputFn + gitOutputFn = func(ctx context.Context, args ...string) (string, error) { + if fmt.Sprint(args) != "[diff --name-only old-sha new-sha]" { + t.Fatalf("unexpected git args: %v", args) + } + return "a.dat\nb.txt\n", nil + } + t.Cleanup(func() { gitOutputFn = oldFn }) + + got, err := listRefUpdatePaths(context.Background(), []internaltransfer.RefUpdate{{OldSHA: "old-sha", NewSHA: "new-sha"}}) + if err != nil { + t.Fatalf("listRefUpdatePaths returned error: %v", err) + } + if len(got) != 2 || got[0] != "a.dat" || got[1] != "b.txt" { + t.Fatalf("unexpected paths: %+v", got) + } +} + +func TestListRefUpdatePathsUsesLsTreeForFirstPush(t *testing.T) { + oldFn := gitOutputFn + gitOutputFn = func(ctx context.Context, args ...string) (string, error) { + if fmt.Sprint(args) != "[ls-tree -r --name-only new-sha]" { + t.Fatalf("unexpected git args: %v", args) + } + return "a.dat\nb.txt\n", nil + } + t.Cleanup(func() { gitOutputFn = oldFn }) + + got, err := listRefUpdatePaths(context.Background(), []internaltransfer.RefUpdate{{OldSHA: "0000000000000000000000000000000000000000", NewSHA: "new-sha"}}) + if err != nil { + t.Fatalf("listRefUpdatePaths returned error: %v", err) + } + if len(got) != 2 || got[0] != "a.dat" || got[1] != "b.txt" { + t.Fatalf("unexpected paths: %+v", got) + } +} diff --git a/cmd/query/main.go b/cmd/query/main.go index dbda1a26..79028979 100644 --- a/cmd/query/main.go +++ b/cmd/query/main.go @@ -5,10 +5,11 @@ import ( "fmt" "strings" - "github.com/calypr/git-drs/internal/common" + "github.com/bytedance/sonic" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" + "github.com/calypr/git-drs/internal/lookup" + "github.com/calypr/git-drs/internal/remoteruntime" drsapi "github.com/calypr/syfon/apigen/client/drs" "github.com/calypr/syfon/client/hash" "github.com/spf13/cobra" @@ -18,12 +19,12 @@ var remote string var checksum = false var pretty = false -func queryByChecksum(ctx context.Context, gc *config.GitContext, checksum string) ([]drsapi.DrsObject, error) { +func queryByChecksum(ctx context.Context, gc *remoteruntime.GitContext, checksum string) ([]drsapi.DrsObject, error) { hashType := checksumTypeForString(checksum) if hashType != hash.ChecksumTypeSHA256.String() { return nil, fmt.Errorf("checksum lookup currently only supports sha256 (got %q); non-sha256 support is tracked in syfon DRSService.GetObjectsByChecksum", hashType) } - return drsremote.ObjectsByHashForScope(ctx, gc, checksum) + return lookup.ObjectsByHashForScope(ctx, gc, checksum) } func checksumTypeForString(sum string) string { @@ -41,6 +42,23 @@ func checksumTypeForString(sum string) string { } } +func printDRSObject(obj drsapi.DrsObject, pretty bool) error { + var out []byte + var err error + + if pretty { + out, err = sonic.ConfigFastest.MarshalIndent(obj, "", " ") + } else { + out, err = sonic.ConfigFastest.Marshal(obj) + } + if err != nil { + return err + } + + fmt.Printf("%s\n", string(out)) + return nil +} + // Cmd line declaration var Cmd = &cobra.Command{ Use: "query ", @@ -67,7 +85,7 @@ var Cmd = &cobra.Command{ return err } - gc, err := cfg.GetRemoteClient(remoteName, logger) + gc, err := remoteruntime.New(cfg, remoteName, logger) if err != nil { return err } @@ -78,7 +96,7 @@ var Cmd = &cobra.Command{ return err } for _, drsObj := range objs { - if err := common.PrintDRSObject(drsObj, pretty); err != nil { + if err := printDRSObject(drsObj, pretty); err != nil { return err } } @@ -89,7 +107,7 @@ var Cmd = &cobra.Command{ if err != nil { return err } - return common.PrintDRSObject(obj, pretty) + return printDRSObject(obj, pretty) }, } diff --git a/cmd/query/main_test.go b/cmd/query/main_test.go index 25a05808..5f1444c7 100644 --- a/cmd/query/main_test.go +++ b/cmd/query/main_test.go @@ -3,7 +3,6 @@ package query import ( "testing" - "github.com/calypr/git-drs/internal/common" drsapi "github.com/calypr/syfon/apigen/client/drs" ) @@ -27,10 +26,10 @@ func TestChecksumTypeForString(t *testing.T) { func TestPrintDRSObject(t *testing.T) { obj := drsapi.DrsObject{Id: "test-id"} - if err := common.PrintDRSObject(obj, false); err != nil { - t.Fatalf("common.PrintDRSObject failed: %v", err) + if err := printDRSObject(obj, false); err != nil { + t.Fatalf("printDRSObject failed: %v", err) } - if err := common.PrintDRSObject(obj, true); err != nil { - t.Fatalf("common.PrintDRSObject pretty failed: %v", err) + if err := printDRSObject(obj, true); err != nil { + t.Fatalf("printDRSObject pretty failed: %v", err) } } diff --git a/cmd/remote/add/add_test.go b/cmd/remote/add/add_test.go index 2923b5d8..7e7f2d94 100644 --- a/cmd/remote/add/add_test.go +++ b/cmd/remote/add/add_test.go @@ -1,8 +1,14 @@ package add import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" "testing" + bucketapi "github.com/calypr/syfon/apigen/client/bucketapi" "github.com/stretchr/testify/assert" ) @@ -12,5 +18,136 @@ func TestAddCmd(t *testing.T) { } func TestGen3Cmd(t *testing.T) { - assert.Equal(t, "gen3 [remote-name]", Gen3Cmd.Use) + assert.Equal(t, "gen3 [remote-name] ", Gen3Cmd.Use) +} + +func TestParseScopeArg(t *testing.T) { + t.Run("splits org and project on slash", func(t *testing.T) { + org, project, err := parseScopeArg("HTAN_INT/BForePC") + if err != nil { + t.Fatalf("parseScopeArg returned error: %v", err) + } + if org != "HTAN_INT" || project != "BForePC" { + t.Fatalf("unexpected scope parse result: %q/%q", org, project) + } + }) + + t.Run("rejects legacy single token input", func(t *testing.T) { + _, _, err := parseScopeArg("BForePC") + if err == nil { + t.Fatal("expected invalid scope error") + } + }) + + t.Run("rejects empty org or project", func(t *testing.T) { + for _, raw := range []string{"/BForePC", "HTAN_INT/", "HTAN_INT//BForePC"} { + _, _, err := parseScopeArg(raw) + if err == nil { + t.Fatalf("expected invalid scope error for %q", raw) + } + } + }) +} + +func TestResolveBucketScopeFromServer(t *testing.T) { + t.Run("matches project resource", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/data/buckets" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-token" { + t.Fatalf("unexpected auth header: %q", got) + } + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/HTAN_INT/project/BForePC"]}}}`)) + })) + defer srv.Close() + + scope, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "HTAN_INT", "BForePC", "") + if err != nil { + t.Fatalf("resolveBucketScopeFromServer returned error: %v", err) + } + if scope.Bucket != "cbds" { + t.Fatalf("unexpected bucket: %+v", scope) + } + }) + + t.Run("falls back to org resource", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/HTAN_INT"]}}}`)) + })) + defer srv.Close() + + scope, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "HTAN_INT", "BForePC", "") + if err != nil { + t.Fatalf("resolveBucketScopeFromServer returned error: %v", err) + } + if scope.Bucket != "cbds" { + t.Fatalf("unexpected bucket: %+v", scope) + } + }) + + t.Run("no match", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := bucketapi.BucketsResponse{S3BUCKETS: map[string]bucketapi.BucketMetadata{ + "cbds": {}, + }} + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("encode response: %v", err) + } + })) + defer srv.Close() + + _, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "HTAN_INT", "BForePC", "") + if err == nil { + t.Fatal("expected error when no matching bucket is visible") + } + }) + + t.Run("reports ambiguity with candidate buckets", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := bucketapi.BucketsResponse{S3BUCKETS: map[string]bucketapi.BucketMetadata{ + "EllrottLab": {Programs: &[]string{"/organization/Ellrott_Lab/project/hla2vec"}}, + "cbds": {Programs: &[]string{"/organization/Ellrott_Lab/project/hla2vec"}}, + }} + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("encode response: %v", err) + } + })) + defer srv.Close() + + _, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "Ellrott_Lab", "hla2vec", "") + if err == nil { + t.Fatal("expected ambiguity error") + } + if !strings.Contains(err.Error(), "multiple visible server buckets matched") { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(err.Error(), "EllrottLab, cbds") { + t.Fatalf("expected candidate list in error, got: %v", err) + } + }) + + t.Run("uses selected bucket when ambiguity exists", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := bucketapi.BucketsResponse{S3BUCKETS: map[string]bucketapi.BucketMetadata{ + "EllrottLab": {Programs: &[]string{"/organization/Ellrott_Lab/project/hla2vec"}}, + "cbds": {Programs: &[]string{"/organization/Ellrott_Lab/project/hla2vec"}}, + }} + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("encode response: %v", err) + } + })) + defer srv.Close() + + scope, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "Ellrott_Lab", "hla2vec", "cbds") + if err != nil { + t.Fatalf("resolveBucketScopeFromServer returned error: %v", err) + } + if scope.Bucket != "cbds" { + t.Fatalf("unexpected bucket: %+v", scope) + } + }) } diff --git a/cmd/remote/add/bucket_lookup.go b/cmd/remote/add/bucket_lookup.go new file mode 100644 index 00000000..4721fe49 --- /dev/null +++ b/cmd/remote/add/bucket_lookup.go @@ -0,0 +1,147 @@ +package add + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "slices" + "strings" + + "github.com/calypr/git-drs/internal/gitrepo" + bucketapi "github.com/calypr/syfon/apigen/client/bucketapi" + syfoncommon "github.com/calypr/syfon/common" +) + +func resolveBucketFromPayload(payload bucketapi.BucketsResponse, organization, project, preferredBucket string) (string, error) { + projectResource, err := syfoncommon.ResourcePath(organization, project) + if err != nil { + return "", err + } + orgResource, err := syfoncommon.ResourcePath(organization, "") + if err != nil { + return "", err + } + + if bucket, err := chooseVisibleBucket(payload, projectResource, preferredBucket); err != nil || bucket != "" { + return bucket, err + } + if bucket, err := chooseVisibleBucket(payload, orgResource, preferredBucket); err != nil || bucket != "" { + return bucket, err + } + + return "", fmt.Errorf("no visible server bucket matched organization=%q project=%q", organization, project) +} + +func chooseVisibleBucket(payload bucketapi.BucketsResponse, resource, preferredBucket string) (string, error) { + matches := findBucketsByResource(payload, resource) + if len(matches) == 0 { + return "", nil + } + + preferredBucket = strings.TrimSpace(preferredBucket) + if preferredBucket != "" { + for _, bucket := range matches { + if strings.EqualFold(bucket, preferredBucket) { + return bucket, nil + } + } + return "", fmt.Errorf("selected bucket %q does not match resource %q; choose one of: %s", preferredBucket, resource, strings.Join(matches, ", ")) + } + + if len(matches) == 1 { + return matches[0], nil + } + return "", fmt.Errorf("multiple visible server buckets matched resource %q: %s; rerun with --bucket ", resource, strings.Join(matches, ", ")) +} + +func findBucketsByResource(payload bucketapi.BucketsResponse, resource string) []string { + resource = syfoncommon.NormalizeAccessResource(resource) + if resource == "" { + return nil + } + + matches := make([]string, 0) + for bucket, meta := range payload.S3BUCKETS { + if meta.Programs == nil { + continue + } + for _, candidate := range *meta.Programs { + if syfoncommon.NormalizeAccessResource(candidate) == resource { + matches = append(matches, bucket) + break + } + } + } + slices.Sort(matches) + return slices.Compact(matches) +} + +func resolveBucketScopeFromServer(ctx context.Context, endpoint, token, organization, project, preferredBucket string) (gitrepo.ResolvedBucketScope, error) { + if strings.TrimSpace(endpoint) == "" { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("missing API endpoint for server bucket lookup") + } + if strings.TrimSpace(token) == "" { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("missing access token for server bucket lookup") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimRight(endpoint, "/")+"/data/buckets", nil) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("build bucket list request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("request bucket list: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("bucket list failed with status %d", resp.StatusCode) + } + + var payload bucketapi.BucketsResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("decode bucket list response: %w", err) + } + + bucket, err := resolveBucketFromPayload(payload, organization, project, preferredBucket) + if err != nil { + return gitrepo.ResolvedBucketScope{}, err + } + return gitrepo.ResolvedBucketScope{Bucket: bucket}, nil +} + +func resolveBucketScopeFromLocalServer(ctx context.Context, endpoint, username, password, organization, project, preferredBucket string) (gitrepo.ResolvedBucketScope, error) { + if strings.TrimSpace(endpoint) == "" { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("missing API endpoint for server bucket lookup") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimRight(endpoint, "/")+"/data/buckets", nil) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("build bucket list request: %w", err) + } + if username != "" || password != "" { + req.SetBasicAuth(username, password) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("request bucket list: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("bucket list failed with status %d", resp.StatusCode) + } + + var payload bucketapi.BucketsResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("decode bucket list response: %w", err) + } + + bucket, err := resolveBucketFromPayload(payload, organization, project, preferredBucket) + if err != nil { + return gitrepo.ResolvedBucketScope{}, err + } + return gitrepo.ResolvedBucketScope{Bucket: bucket}, nil +} diff --git a/cmd/remote/add/gen3.go b/cmd/remote/add/gen3.go index 9f9ceed7..504cd1ad 100644 --- a/cmd/remote/add/gen3.go +++ b/cmd/remote/add/gen3.go @@ -5,71 +5,62 @@ import ( "fmt" "log/slog" "strings" + "time" - "github.com/calypr/data-client/credentials" - "github.com/calypr/git-drs/internal/common" + "github.com/calypr/calypr-cli/conf" + "github.com/calypr/calypr-cli/credentials" + "github.com/calypr/git-drs/cmd/initialize" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" - conf "github.com/calypr/syfon/client/config" + "github.com/calypr/git-drs/internal/remoteruntime" "github.com/spf13/cobra" ) var Gen3Cmd = &cobra.Command{ - Use: "gen3 [remote-name]", + Use: "gen3 [remote-name] ", Args: func(cmd *cobra.Command, args []string) error { - if len(args) > 1 { + if len(args) < 1 || len(args) > 2 { cmd.SilenceUsage = false - return fmt.Errorf("error: accepts at most 1 argument (remote name), received %d\n\nUsage: %s\n\nSee 'git drs remote add gen3 --help' for more details", len(args), cmd.UseLine()) + return fmt.Errorf("error: expected [remote-name] , received %d arguments\n\nUsage: %s\n\nSee 'git drs remote add gen3 --help' for more details", len(args), cmd.UseLine()) } return nil }, RunE: func(cmd *cobra.Command, args []string) error { logg := drslog.GetLogger() - // make sure at least one of the credentials params is provided - if credFile == "" && fenceToken == "" && len(args) == 0 { - return fmt.Errorf("error: Gen3 requires a credentials file or accessToken to setup project locally. Please provide either a --cred or --token flag. See 'git drs remote add gen3 --help' for more details") - } - remoteName := config.ORIGIN - if len(args) > 0 { + scopeArg := "" + if len(args) == 1 { + scopeArg = args[0] + } else { remoteName = args[0] + scopeArg = args[1] } - err := gen3Init(remoteName, credFile, fenceToken, project, organization, bucket, logg) + err := gen3Init(remoteName, credFile, fenceToken, scopeArg, logg) if err != nil { return fmt.Errorf("error configuring gen3 server: %v", err) } + if noSkipSmudge { + if err := gitrepo.SetGitConfigOptions(map[string]string{"drs.skipsmudge": "false"}); err != nil { + return fmt.Errorf("failed to configure skipsmudge: %w", err) + } + } return nil }, } -func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket string, logg *slog.Logger) error { +func gen3Init(remoteName, credFile, fenceToken, scopeArg string, logg *slog.Logger) error { if remoteName == "" { return fmt.Errorf("remote name is required") } - if project == "" { - return fmt.Errorf("project is required for Gen3 remote") - } - - resolvedBucket := strings.TrimSpace(bucket) - resolvedStoragePrefix := "" - if strings.TrimSpace(organization) != "" { - scope, err := gitrepo.ResolveBucketScope(organization, project, resolvedBucket, "") - if err != nil { - return fmt.Errorf("failed resolving bucket mapping for organization=%q project=%q: %w", organization, project, err) - } - resolvedBucket = strings.TrimSpace(scope.Bucket) - resolvedStoragePrefix = strings.TrimSpace(scope.Prefix) + if err := initialize.EnsureInitialized(logg); err != nil { + return fmt.Errorf("failed to initialize repository: %w", err) } - if resolvedBucket == "" { - if strings.TrimSpace(organization) == "" { - return fmt.Errorf("bucket is required when organization is empty") - } - if strings.TrimSpace(resolvedBucket) == "" { - return fmt.Errorf("bucket is required (or configure mapping first with `git drs bucket add-project --organization %s --project %s --path :///`)", organization, project) - } + organization, project, err := parseScopeArg(scopeArg) + if err != nil { + return err } var accessToken, apiKey, keyID, apiEndpoint string @@ -78,7 +69,7 @@ func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket st case fenceToken != "": accessToken = fenceToken var err error - apiEndpoint, err = common.ParseAPIEndpointFromToken(accessToken) + apiEndpoint, err = config.ParseAPIEndpointFromToken(accessToken) if err != nil { return fmt.Errorf("failed to parse API endpoint from provided access token: %w", err) } @@ -92,20 +83,20 @@ func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket st apiKey = cred.APIKey keyID = cred.KeyID - apiEndpoint, err = common.ParseAPIEndpointFromToken(cred.APIKey) + apiEndpoint, err = config.ParseAPIEndpointFromToken(cred.APIKey) if err != nil { return fmt.Errorf("failed to parse API endpoint from API key in credentials file: %w", err) } default: existing, err := configure.Load(remoteName) - if err == nil { + if err != nil { + return fmt.Errorf("failed to load %s config: %w", remoteName, err) + } else { accessToken = existing.AccessToken apiKey = existing.APIKey keyID = existing.KeyID apiEndpoint = existing.APIEndpoint - } else { - return fmt.Errorf("must provide either --cred or --token (or have existing profile %s)", remoteName) } } @@ -113,52 +104,41 @@ func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket st return fmt.Errorf("could not determine Gen3 API endpoint") } - remoteGen3 := config.RemoteSelect{ - Gen3: &config.Gen3Remote{ - Endpoint: apiEndpoint, - ProjectID: project, - Organization: organization, - Bucket: resolvedBucket, - StoragePrefix: resolvedStoragePrefix, - }, - } - - remote := config.Remote(remoteName) - if _, err := config.UpdateRemote(remote, remoteGen3); err != nil { - return fmt.Errorf("failed to update remote config: %w", err) - } - logg.Debug(fmt.Sprintf("Remote added/updated: %s → %s (project: %s, bucket: %s, storage_prefix: %s)", remoteName, apiEndpoint, project, resolvedBucket, resolvedStoragePrefix)) - - // Step 3: Ensure credential profile is up-to-date (refreshes token if needed) cred := &conf.Credential{ Profile: remoteName, APIEndpoint: apiEndpoint, APIKey: apiKey, KeyID: keyID, AccessToken: accessToken, // may be stale - UseShepherd: "false", // or preserve from existing? + UseShepherd: "false", MinShepherdVersion: "", } - if err := credentials.EnsureValidCredential(context.Background(), cred, logg); err != nil { - return fmt.Errorf("failed to verify/refresh Gen3 credential: %w", err) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + if err := credentials.EnsureValidCredential(ctx, cred, logg); err != nil { + return fmt.Errorf("failed to verify/refresh Gen3 credential: %w", remoteruntime.WrapCredentialValidationError(remoteName, err)) } - if err := configure.Save(cred); err != nil { - return fmt.Errorf("failed to configure/update Gen3 profile: %w", err) - } - // Configure stock git credential plumbing for lfs + persist the refreshed token locally. - if err := gitrepo.ConfigureCredentialHelperForRepo(); err != nil { - return fmt.Errorf("failed to configure git credential helper: %w", err) + scope, err := gitrepo.ResolveBucketScope(organization, project, "", "") + if err != nil { + scope, err = resolveBucketScopeFromServer(context.Background(), apiEndpoint, strings.TrimSpace(cred.AccessToken), organization, project, selectedBucket) + if err != nil { + return fmt.Errorf("failed resolving bucket mapping for organization=%q project=%q: %w", organization, project, err) + } } - if err := gitrepo.SetRemoteLFSURL(remoteName, apiEndpoint); err != nil { - return fmt.Errorf("failed to set lfs url for remote %s: %w", remoteName, err) + resolvedBucket := strings.TrimSpace(scope.Bucket) + resolvedStoragePrefix := strings.TrimSpace(scope.Prefix) + if resolvedBucket == "" { + return fmt.Errorf("no bucket mapping found for organization=%q project=%q", organization, project) } - if strings.TrimSpace(cred.AccessToken) != "" { - if err := gitrepo.SetRemoteToken(remoteName, strings.TrimSpace(cred.AccessToken)); err != nil { - return fmt.Errorf("failed to persist repo token for remote %s: %w", remoteName, err) - } + + if err := persistGen3Remote(remoteName, organization, project, apiEndpoint, scope, func() error { + return configure.Save(cred) + }); err != nil { + return err } + logg.Debug(fmt.Sprintf("Remote added/updated: %s → %s (project: %s, bucket: %s, storage_prefix: %s)", remoteName, apiEndpoint, project, resolvedBucket, resolvedStoragePrefix)) logg.Debug(fmt.Sprintf("Gen3 profile '%s' configured and token refreshed successfully", remoteName)) return nil diff --git a/cmd/remote/add/init.go b/cmd/remote/add/init.go index 55f848f2..4248b546 100644 --- a/cmd/remote/add/init.go +++ b/cmd/remote/add/init.go @@ -3,14 +3,12 @@ package add import "github.com/spf13/cobra" var ( - apiEndpoint string - bucket string - credFile string - fenceToken string - localPassword string - localUsername string - project string - organization string + credFile string + fenceToken string + selectedBucket string + localPassword string + localUsername string + noSkipSmudge bool ) // Cmd line declaration @@ -20,18 +18,15 @@ var Cmd = &cobra.Command{ } func init() { - Gen3Cmd.Flags().StringVar(&apiEndpoint, "url", "", "[gen3] Specify the API endpoint of the data commons") - Gen3Cmd.Flags().StringVar(&bucket, "bucket", "", "[gen3] Specify the bucket name") - Gen3Cmd.Flags().StringVar(&credFile, "cred", "", "[gen3] Specify the gen3 credential file that you want to use") - Gen3Cmd.Flags().StringVar(&fenceToken, "token", "", "[gen3] Specify the token to be used as a replacement for a credential file for temporary access") - Gen3Cmd.Flags().StringVar(&project, "project", "", "[gen3] Specify the gen3 project ID in the format -") - Gen3Cmd.Flags().StringVar(&organization, "organization", "", "[gen3] Optional organization/program scope (use with --project as project id)") + Gen3Cmd.Flags().StringVar(&credFile, "cred", "", "[gen3] Import a Gen3 credential file into this profile") + Gen3Cmd.Flags().StringVar(&fenceToken, "token", "", "[gen3] Use a temporary bearer token issued from fence") + Gen3Cmd.Flags().StringVar(&selectedBucket, "bucket", "", "[gen3] Select a specific visible bucket when multiple buckets match the scope") + Gen3Cmd.Flags().BoolVar(&noSkipSmudge, "no-skip-smudge", false, "Disable skipping smudge filter (force downloading file contents during checkout)") Cmd.AddCommand(Gen3Cmd) - LocalCmd.Flags().StringVarP(&project, "project", "p", "", "Project ID") - LocalCmd.Flags().StringVar(&bucket, "bucket", "", "Bucket Name") - LocalCmd.Flags().StringVar(&organization, "organization", "", "Organization Name") + LocalCmd.Flags().StringVar(&selectedBucket, "bucket", "", "Select a specific visible bucket when multiple buckets match the scope") LocalCmd.Flags().StringVar(&localUsername, "username", "", "Username for local DRS HTTP basic auth") LocalCmd.Flags().StringVar(&localPassword, "password", "", "Password for local DRS HTTP basic auth") + LocalCmd.Flags().BoolVar(&noSkipSmudge, "no-skip-smudge", false, "Disable skipping smudge filter (force downloading file contents during checkout)") Cmd.AddCommand(LocalCmd) } diff --git a/cmd/remote/add/local.go b/cmd/remote/add/local.go index c0d61b29..8a34b330 100644 --- a/cmd/remote/add/local.go +++ b/cmd/remote/add/local.go @@ -1,56 +1,63 @@ package add import ( + "context" "fmt" "strings" + "github.com/calypr/git-drs/cmd/initialize" "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" "github.com/spf13/cobra" ) var LocalCmd = &cobra.Command{ - Use: "local ", + Use: "local ", Short: "Add a local DRS server", - Long: "Add a local DRS server by specifying its base URL, e.g., http://localhost:8000. Optional --username/--password configures basic auth for git-lfs and helper flows.", - Args: cobra.ExactArgs(2), + Long: "Add a local DRS server by specifying its base URL and scope. Optional --username/--password configures basic auth for helper flows.", + Args: cobra.ExactArgs(3), RunE: func(cmd *cobra.Command, args []string) error { remoteName := args[0] url := args[1] + scopeArg := args[2] + if err := initialize.EnsureInitialized(drslog.GetLogger()); err != nil { + return fmt.Errorf("failed to initialize repository: %w", err) + } if url == "" { return fmt.Errorf("URL cannot be empty") } - - remoteSelect := config.RemoteSelect{ - Local: &config.LocalRemote{ - BaseURL: url, - ProjectID: project, - Bucket: bucket, - Organization: organization, - }, - } - - newConfig, err := config.UpdateRemote(config.Remote(remoteName), remoteSelect) + organization, project, err := parseScopeArg(scopeArg) if err != nil { return err } - if err := gitrepo.SetRemoteLFSURL(remoteName, url); err != nil { - return fmt.Errorf("failed to configure lfs url for remote %q: %w", remoteName, err) + scope, err := gitrepo.ResolveBucketScope(organization, project, "", "") + if err != nil { + scope, err = resolveBucketScopeFromLocalServer(context.Background(), url, strings.TrimSpace(localUsername), strings.TrimSpace(localPassword), organization, project, selectedBucket) + if err != nil { + return fmt.Errorf("failed resolving bucket mapping for organization=%q project=%q: %w", organization, project, err) + } } - if err := gitrepo.ConfigureCredentialHelperForRepo(); err != nil { - return fmt.Errorf("failed to configure git credential helper: %w", err) + resolvedBucket := strings.TrimSpace(scope.Bucket) + if resolvedBucket == "" { + return fmt.Errorf("no bucket mapping found for organization=%q project=%q", organization, project) } - if strings.TrimSpace(localUsername) != "" || strings.TrimSpace(localPassword) != "" { - if strings.TrimSpace(localUsername) == "" || strings.TrimSpace(localPassword) == "" { - return fmt.Errorf("both --username and --password are required when configuring local basic auth") - } - if err := gitrepo.SetRemoteBasicAuth(remoteName, strings.TrimSpace(localUsername), strings.TrimSpace(localPassword)); err != nil { - return fmt.Errorf("failed to configure local basic auth for remote %q: %w", remoteName, err) - } + + newConfig, err := persistLocalRemote(remoteName, url, organization, project, scope) + if err != nil { + return err + } + if err := configureLocalBasicAuth(remoteName, localUsername, localPassword); err != nil { + return err } fmt.Printf("Added remote '%s'. Config: %v\n", remoteName, newConfig.GetRemote(config.Remote(remoteName))) + if noSkipSmudge { + if err := gitrepo.SetGitConfigOptions(map[string]string{"drs.skipsmudge": "false"}); err != nil { + return fmt.Errorf("failed to configure skipsmudge: %w", err) + } + } return nil }, } diff --git a/cmd/remote/add/local_test.go b/cmd/remote/add/local_test.go index 80e908a6..23c3ded3 100644 --- a/cmd/remote/add/local_test.go +++ b/cmd/remote/add/local_test.go @@ -1,14 +1,100 @@ package add import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" "testing" + "github.com/calypr/git-drs/internal/gitrepo" + "github.com/calypr/git-drs/internal/testutils" "github.com/stretchr/testify/assert" ) func TestAddLocalRemote(t *testing.T) { assert.NotNil(t, LocalCmd) - assert.Equal(t, "local ", LocalCmd.Use) + assert.Equal(t, "local ", LocalCmd.Use) assert.NotNil(t, LocalCmd.Flag("username")) assert.NotNil(t, LocalCmd.Flag("password")) + assert.NotNil(t, LocalCmd.Flag("bucket")) + assert.Nil(t, LocalCmd.Flag("organization")) + assert.Nil(t, LocalCmd.Flag("project")) +} + +func TestResolveBucketScopeFromLocalServer(t *testing.T) { + t.Run("matches project resource", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/data/buckets" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + user, pass, ok := r.BasicAuth() + if !ok || user != "drs-user" || pass != "drs-pass" { + t.Fatalf("unexpected basic auth: ok=%v user=%q pass=%q", ok, user, pass) + } + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/calypr/project/end_to_end_test"]}}}`)) + })) + defer srv.Close() + + scope, err := resolveBucketScopeFromLocalServer(context.Background(), srv.URL, "drs-user", "drs-pass", "calypr", "end_to_end_test", "") + if err != nil { + t.Fatalf("resolveBucketScopeFromLocalServer returned error: %v", err) + } + if scope.Bucket != "cbds" { + t.Fatalf("unexpected bucket: %+v", scope) + } + }) +} + +func TestLocalRemoteAddEnsuresInitialization(t *testing.T) { + testutils.SetupTestGitRepo(t) + localUsername = "" + localPassword = "" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/data/buckets" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/calypr/project/end_to_end_test"]}}}`)) + })) + defer srv.Close() + + if err := LocalCmd.RunE(LocalCmd, []string{"origin", srv.URL, "calypr/end_to_end_test"}); err != nil { + t.Fatalf("LocalCmd.RunE returned error: %v", err) + } + + if _, err := os.Stat(gitrepo.DRSDir); err != nil { + t.Fatalf("expected %s to exist: %v", gitrepo.DRSDir, err) + } + + filterProcess, err := gitrepo.GetGitConfigString("filter.drs.process") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.process): %v", err) + } + if filterProcess != "git-drs filter" { + t.Fatalf("unexpected filter.drs.process: %q", filterProcess) + } + filterClean, err := gitrepo.GetGitConfigString("filter.drs.clean") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.clean): %v", err) + } + if filterClean != "git-drs clean -- %f" { + t.Fatalf("unexpected filter.drs.clean: %q", filterClean) + } + filterSmudge, err := gitrepo.GetGitConfigString("filter.drs.smudge") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.smudge): %v", err) + } + if filterSmudge != "git-drs smudge -- %f" { + t.Fatalf("unexpected filter.drs.smudge: %q", filterSmudge) + } + + preCommit, err := os.ReadFile(filepath.Join(".git", "hooks", "pre-commit")) + if err != nil { + t.Fatalf("read pre-commit hook: %v", err) + } + if string(preCommit) == "" { + t.Fatalf("expected pre-commit hook to be installed") + } } diff --git a/cmd/remote/add/persist.go b/cmd/remote/add/persist.go new file mode 100644 index 00000000..bec1e5b5 --- /dev/null +++ b/cmd/remote/add/persist.go @@ -0,0 +1,79 @@ +package add + +import ( + "fmt" + "strings" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/gitrepo" +) + +func persistGen3Remote(remoteName, organization, project, endpoint string, scope gitrepo.ResolvedBucketScope, saveCredential func() error) error { + remote := config.Remote(remoteName) + remoteGen3 := config.RemoteSelect{ + Gen3: &config.Gen3Remote{ + Endpoint: endpoint, + ProjectID: project, + Organization: organization, + Bucket: strings.TrimSpace(scope.Bucket), + StoragePrefix: strings.TrimSpace(scope.Prefix), + }, + } + + if _, err := config.UpdateRemote(remote, remoteGen3); err != nil { + return fmt.Errorf("failed to update remote config: %w", err) + } + if err := saveCredential(); err != nil { + return fmt.Errorf("failed to configure/update Gen3 profile: %w", err) + } + if err := configureRepoRemote(remoteName, endpoint); err != nil { + return err + } + return nil +} + +func persistLocalRemote(remoteName, url, organization, project string, scope gitrepo.ResolvedBucketScope) (*config.Config, error) { + remoteSelect := config.RemoteSelect{ + Local: &config.LocalRemote{ + BaseURL: url, + ProjectID: project, + Bucket: strings.TrimSpace(scope.Bucket), + Organization: organization, + StoragePrefix: strings.TrimSpace(scope.Prefix), + }, + } + + newConfig, err := config.UpdateRemote(config.Remote(remoteName), remoteSelect) + if err != nil { + return nil, err + } + if err := configureRepoRemote(remoteName, url); err != nil { + return nil, err + } + return newConfig, nil +} + +func configureRepoRemote(remoteName, endpoint string) error { + if err := gitrepo.SetRemoteLFSURL(remoteName, endpoint); err != nil { + return fmt.Errorf("failed to configure lfs url for remote %q: %w", remoteName, err) + } + if err := gitrepo.ConfigureCredentialHelperForRepo(); err != nil { + return fmt.Errorf("failed to configure git credential helper: %w", err) + } + return nil +} + +func configureLocalBasicAuth(remoteName, username, password string) error { + username = strings.TrimSpace(username) + password = strings.TrimSpace(password) + if username == "" && password == "" { + return nil + } + if username == "" || password == "" { + return fmt.Errorf("both --username and --password are required when configuring local basic auth") + } + if err := gitrepo.SetRemoteBasicAuth(remoteName, username, password); err != nil { + return fmt.Errorf("failed to configure local basic auth for remote %q: %w", remoteName, err) + } + return nil +} diff --git a/cmd/remote/add/scope.go b/cmd/remote/add/scope.go new file mode 100644 index 00000000..cc1ef92f --- /dev/null +++ b/cmd/remote/add/scope.go @@ -0,0 +1,24 @@ +package add + +import ( + "fmt" + "strings" +) + +func parseScopeArg(raw string) (string, string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "", fmt.Errorf("organization/project scope is required") + } + + parts := strings.Split(raw, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + organization := strings.TrimSpace(parts[0]) + project := strings.TrimSpace(parts[1]) + if organization == "" || project == "" { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + return organization, project, nil +} diff --git a/cmd/remote/list.go b/cmd/remote/list.go index 8723bc97..a7a3035d 100644 --- a/cmd/remote/list.go +++ b/cmd/remote/list.go @@ -3,11 +3,22 @@ package remote import ( "fmt" + calyprconf "github.com/calypr/calypr-cli/conf" + "github.com/calypr/calypr-cli/credentials" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/remoteruntime" "github.com/spf13/cobra" ) +var ( + loadConfig = config.LoadConfig + loadProfileCredential = func(profile string) (*calyprconf.Credential, error) { + return calyprconf.NewConfigure(drslog.GetLogger()).Load(profile) + } + ensureValidCredential = credentials.EnsureValidCredential +) + var ListCmd = &cobra.Command{ Use: "list", Short: "List DRS repos", @@ -20,7 +31,7 @@ var ListCmd = &cobra.Command{ }, RunE: func(cmd *cobra.Command, args []string) error { logg := drslog.GetLogger() - cfg, err := config.LoadConfig() + cfg, err := loadConfig() if err != nil { logg.Debug(fmt.Sprintf("Error loading config: %s", err)) return err @@ -53,6 +64,16 @@ var ListCmd = &cobra.Command{ } fmt.Printf("%s %-10s %-8s %s\n", marker, name, remoteType, endpoint) + if remoteSelect.Gen3 != nil { + cred, err := loadProfileCredential(string(name)) + if err != nil { + logg.Warn(fmt.Sprintf("remote %s credential check skipped: %v", name, err)) + continue + } + if err := ensureValidCredential(cmd.Context(), cred, logg); err != nil { + logg.Warn(remoteruntime.WrapCredentialValidationError(string(name), err).Error()) + } + } } return nil }, diff --git a/cmd/remote/remote_test.go b/cmd/remote/remote_test.go index b03b3121..c1d33aab 100644 --- a/cmd/remote/remote_test.go +++ b/cmd/remote/remote_test.go @@ -1,9 +1,14 @@ package remote import ( + "context" + "log/slog" + "os/exec" "testing" + "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/testutils" + syconf "github.com/calypr/syfon/client/config" "github.com/stretchr/testify/assert" ) @@ -21,6 +26,22 @@ func TestRemoteListRun(t *testing.T) { tmpDir := testutils.SetupTestGitRepo(t) testutils.CreateDefaultTestConfig(t, tmpDir) + oldLoadProfileCredential := loadProfileCredential + oldEnsureValidCredential := ensureValidCredential + t.Cleanup(func() { + loadProfileCredential = oldLoadProfileCredential + ensureValidCredential = oldEnsureValidCredential + }) + + loadProfileCredential = func(profile string) (*syconf.Credential, error) { + return &syconf.Credential{Profile: profile, AccessToken: "token", APIEndpoint: "https://example.test"}, nil + } + called := false + ensureValidCredential = func(ctx context.Context, cred *syconf.Credential, _ *slog.Logger) error { + called = true + return nil + } + // Capture stdout output := testutils.CaptureStdout(t, func() { err := ListCmd.RunE(ListCmd, []string{}) @@ -29,6 +50,7 @@ func TestRemoteListRun(t *testing.T) { assert.Contains(t, output, "origin") assert.Contains(t, output, "gen3") + assert.True(t, called, "expected remote list to validate the configured credential") } func TestRemoteSetArgs(t *testing.T) { @@ -44,3 +66,89 @@ func TestRemoteSetArgs(t *testing.T) { err = SetCmd.Args(SetCmd, []string{"origin", "extra"}) assert.Error(t, err) } + +func TestRemoteRemoveArgs(t *testing.T) { + err := RemoveCmd.Args(RemoveCmd, []string{"origin"}) + assert.NoError(t, err) + + err = RemoveCmd.Args(RemoveCmd, []string{}) + assert.Error(t, err) + + err = RemoveCmd.Args(RemoveCmd, []string{"origin", "extra"}) + assert.Error(t, err) +} + +func TestRemoteRemoveRunReassignsDefaultAndCleansKeys(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateTestConfig(t, tmpDir, &config.Config{ + DefaultRemote: "origin", + Remotes: map[config.Remote]config.RemoteSelect{ + "origin": { + Gen3: &config.Gen3Remote{ + Endpoint: "https://origin.example", + ProjectID: "origin-proj", + Bucket: "origin-bucket", + }, + }, + "backup": { + Gen3: &config.Gen3Remote{ + Endpoint: "https://backup.example", + ProjectID: "backup-proj", + Bucket: "backup-bucket", + }, + }, + }, + }) + + for _, args := range [][]string{ + {"config", "drs.remote.origin.token", "token"}, + {"config", "drs.remote.origin.username", "alice"}, + {"config", "drs.remote.origin.password", "secret"}, + {"config", "remote.origin.lfsurl", "https://origin.example/info/lfs"}, + } { + cmd := exec.Command("git", args...) + cmd.Dir = tmpDir + err := cmd.Run() + assert.NoError(t, err) + } + + err := RemoveCmd.RunE(RemoveCmd, []string{"origin"}) + assert.NoError(t, err) + + cfg, err := config.LoadConfig() + assert.NoError(t, err) + assert.NotContains(t, cfg.Remotes, config.Remote("origin")) + assert.Equal(t, config.Remote("backup"), cfg.DefaultRemote) + + for _, key := range []string{ + "drs.remote.origin.type", + "drs.remote.origin.endpoint", + "drs.remote.origin.project", + "drs.remote.origin.bucket", + "drs.remote.origin.token", + "drs.remote.origin.username", + "drs.remote.origin.password", + "remote.origin.lfsurl", + } { + val, err := exec.Command("git", "config", "--get", key).CombinedOutput() + assert.Empty(t, string(val)) + assert.Error(t, err) + } +} + +func TestRemoteRemoveRunClearsDefaultWhenLastRemoteRemoved(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateDefaultTestConfig(t, tmpDir) + + err := RemoveCmd.RunE(RemoveCmd, []string{"origin"}) + assert.NoError(t, err) + + cfg, err := config.LoadConfig() + assert.NoError(t, err) + assert.Empty(t, cfg.Remotes) + assert.Equal(t, config.Remote(""), cfg.DefaultRemote) + + val, err := exec.Command("git", "config", "--get", "drs.default-remote").CombinedOutput() + assert.Empty(t, string(val)) + assert.Error(t, err) +} diff --git a/cmd/remote/remove.go b/cmd/remote/remove.go new file mode 100644 index 00000000..a5f5dbdc --- /dev/null +++ b/cmd/remote/remove.go @@ -0,0 +1,59 @@ +package remote + +import ( + "fmt" + "sort" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/spf13/cobra" +) + +var RemoveCmd = &cobra.Command{ + Use: "remove ", + Aliases: []string{"rm"}, + Short: "Remove a DRS remote", + Long: "Remove a configured DRS remote and repair the default remote if needed.", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 { + cmd.SilenceUsage = false + return fmt.Errorf("error: requires exactly 1 argument (remote name), received %d\n\nUsage: %s\n\nRun 'git drs remote list' to see available remotes or 'git drs remote remove --help' for more details", len(args), cmd.UseLine()) + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + remoteName := config.Remote(args[0]) + logger := drslog.GetLogger() + + cfg, err := config.LoadConfig() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + if _, ok := cfg.Remotes[remoteName]; !ok { + availableRemotes := make([]string, 0, len(cfg.Remotes)) + for name := range cfg.Remotes { + availableRemotes = append(availableRemotes, string(name)) + } + sort.Strings(availableRemotes) + return fmt.Errorf( + "remote '%s' not found.\nAvailable remotes: %v", + remoteName, + availableRemotes, + ) + } + + updated, err := config.RemoveRemote(remoteName) + if err != nil { + return fmt.Errorf("failed to remove remote: %w", err) + } + + if updated.DefaultRemote == "" { + logger.Debug(fmt.Sprintf("Removed remote %s; no default remote remains", remoteName)) + return nil + } + + logger.Debug(fmt.Sprintf("Removed remote %s; default remote is now %s", remoteName, updated.DefaultRemote)) + return nil + }, +} diff --git a/cmd/remote/root.go b/cmd/remote/root.go index 7d865720..45a1963d 100644 --- a/cmd/remote/root.go +++ b/cmd/remote/root.go @@ -14,5 +14,6 @@ var Cmd = &cobra.Command{ func init() { Cmd.AddCommand(add.Cmd) Cmd.AddCommand(ListCmd) + Cmd.AddCommand(RemoveCmd) Cmd.AddCommand(SetCmd) } diff --git a/cmd/rm/main.go b/cmd/rm/main.go new file mode 100644 index 00000000..a58124d1 --- /dev/null +++ b/cmd/rm/main.go @@ -0,0 +1,58 @@ +package rm + +import ( + "context" + "fmt" + "os/exec" + "path/filepath" + "strings" + + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/lfs" + "github.com/spf13/cobra" +) + +var runCommand = func(name string, args ...string) error { + cmd := exec.Command(name, args...) + return cmd.Run() +} + +var Cmd = &cobra.Command{ + Use: "rm ...", + Short: "Remove tracked git-drs files", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return run(cmd.Context(), args) + }, +} + +func run(ctx context.Context, args []string) error { + tracked, err := lfs.GetTrackedLfsFiles(drslog.GetLogger()) + if err != nil { + return err + } + + type removal struct { + path string + oid string + } + planned := make([]removal, 0, len(args)) + for _, raw := range args { + path := filepath.ToSlash(filepath.Clean(raw)) + info, ok := tracked[path] + if !ok || strings.TrimSpace(info.Oid) == "" { + return fmt.Errorf("%s is not a tracked git-drs/LFS file", raw) + } + planned = append(planned, removal{path: path, oid: "sha256:" + strings.TrimPrefix(strings.TrimSpace(info.Oid), "sha256:")}) + } + + gitArgs := []string{"rm", "--"} + for _, item := range planned { + gitArgs = append(gitArgs, item.path) + } + if err := runCommand("git", gitArgs...); err != nil { + return err + } + + return nil +} diff --git a/cmd/rm/main_test.go b/cmd/rm/main_test.go new file mode 100644 index 00000000..16c51f28 --- /dev/null +++ b/cmd/rm/main_test.go @@ -0,0 +1,54 @@ +package rm + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "testing" +) + +func TestRunRemovesTrackedFile(t *testing.T) { + repo := t.TempDir() + runGitCmd(t, repo, "init") + runGitCmd(t, repo, "config", "user.email", "test@example.com") + runGitCmd(t, repo, "config", "user.name", "Test User") + runGitCmd(t, repo, "config", "filter.drs.clean", "cat") + runGitCmd(t, repo, "config", "filter.drs.smudge", "cat") + runGitCmd(t, repo, "config", "filter.drs.process", "cat") + runGitCmd(t, repo, "config", "filter.drs.required", "false") + + if err := os.WriteFile(filepath.Join(repo, ".gitattributes"), []byte("*.dat filter=drs diff=drs merge=drs -text\n"), 0o644); err != nil { + t.Fatalf("write .gitattributes: %v", err) + } + path := filepath.Join(repo, "data.dat") + if err := os.WriteFile(path, []byte("version https://git-lfs.github.com/spec/v1\noid sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nsize 12\n"), 0o644); err != nil { + t.Fatalf("write pointer file: %v", err) + } + runGitCmd(t, repo, "add", ".") + runGitCmd(t, repo, "commit", "-m", "add pointer") + + oldWD, _ := os.Getwd() + if err := os.Chdir(repo); err != nil { + t.Fatalf("chdir repo: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(oldWD) }) + + if err := run(context.Background(), []string{"data.dat"}); err != nil { + t.Fatalf("run returned error: %v", err) + } + + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("expected file removed from worktree, stat err=%v", err) + } +} + +func runGitCmd(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %v failed: %v\n%s", args, err, string(out)) + } +} diff --git a/cmd/root.go b/cmd/root.go index ddfc95ac..320592f8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,28 +5,23 @@ import ( "github.com/calypr/git-drs/cmd/addurl" "github.com/calypr/git-drs/cmd/bucket" "github.com/calypr/git-drs/cmd/clean" + "github.com/calypr/git-drs/cmd/copyrecords" deleteCmd "github.com/calypr/git-drs/cmd/delete" "github.com/calypr/git-drs/cmd/deleteproject" - - "github.com/calypr/git-drs/cmd/download" - "github.com/calypr/git-drs/cmd/fetch" "github.com/calypr/git-drs/cmd/filter" "github.com/calypr/git-drs/cmd/initialize" "github.com/calypr/git-drs/cmd/install" - - "github.com/calypr/git-drs/cmd/list" "github.com/calypr/git-drs/cmd/lsfiles" + "github.com/calypr/git-drs/cmd/ping" "github.com/calypr/git-drs/cmd/precommit" - "github.com/calypr/git-drs/cmd/prepush" "github.com/calypr/git-drs/cmd/pull" "github.com/calypr/git-drs/cmd/push" "github.com/calypr/git-drs/cmd/query" "github.com/calypr/git-drs/cmd/remote" + "github.com/calypr/git-drs/cmd/rm" "github.com/calypr/git-drs/cmd/smudge" "github.com/calypr/git-drs/cmd/track" "github.com/calypr/git-drs/cmd/untrack" - - "github.com/calypr/git-drs/cmd/upload" "github.com/calypr/git-drs/cmd/version" "github.com/spf13/cobra" ) @@ -41,20 +36,20 @@ var RootCmd = &cobra.Command{ func init() { // Hide internal commands precommit.Cmd.Hidden = true - prepush.Cmd.Hidden = true filter.Cmd.Hidden = true RootCmd.AddCommand(initialize.Cmd) RootCmd.AddCommand(version.Cmd) + RootCmd.AddCommand(ping.Cmd) RootCmd.AddCommand(filter.Cmd) RootCmd.AddCommand(clean.Cmd) + RootCmd.AddCommand(copyrecords.Cmd) RootCmd.AddCommand(smudge.Cmd) RootCmd.AddCommand(remote.Cmd) - RootCmd.AddCommand(fetch.Cmd) + RootCmd.AddCommand(rm.Cmd) RootCmd.AddCommand(pull.Cmd) RootCmd.AddCommand(push.Cmd) RootCmd.AddCommand(precommit.Cmd) - RootCmd.AddCommand(prepush.Cmd) RootCmd.AddCommand(addref.Cmd) RootCmd.AddCommand(addurl.Cmd) RootCmd.AddCommand(deleteCmd.Cmd) @@ -63,10 +58,7 @@ func init() { RootCmd.AddCommand(bucket.Cmd) RootCmd.AddCommand(track.Cmd) RootCmd.AddCommand(untrack.Cmd) - RootCmd.AddCommand(list.Cmd) RootCmd.AddCommand(lsfiles.Cmd) - RootCmd.AddCommand(upload.Cmd) - RootCmd.AddCommand(download.Cmd) RootCmd.AddCommand(install.Cmd) RootCmd.CompletionOptions.HiddenDefaultCmd = true diff --git a/cmd/smudge/main.go b/cmd/smudge/main.go index 9964769c..faebd995 100644 --- a/cmd/smudge/main.go +++ b/cmd/smudge/main.go @@ -7,9 +7,10 @@ import ( "os" "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drsfilter" "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" + internalfilter "github.com/calypr/git-drs/internal/filter" + "github.com/calypr/git-drs/internal/remoteruntime" + internaltransfer "github.com/calypr/git-drs/internal/transfer" "github.com/spf13/cobra" ) @@ -51,19 +52,24 @@ func runSmudge(cmd *cobra.Command, args []string) error { if err != nil { if errors.Is(err, config.ErrNoDefaultRemote) { logger.Debug("smudge: no default remote configured; passing through pointer", "pathname", pathname) - return drsfilter.SmudgeContent(ctx, pathname, os.Stdin, os.Stdout, logger, nil) + return internalfilter.SmudgeContent(ctx, pathname, os.Stdin, os.Stdout, logger, nil) } return fmt.Errorf("smudge: get default remote: %w", err) } - drsCtx, err := cfg.GetRemoteClient(remote, logger) + drsCtx, err := remoteruntime.New(cfg, remote, logger) if err != nil { return fmt.Errorf("smudge: create DRS client: %w", err) } - return drsfilter.SmudgeContent(ctx, pathname, os.Stdin, os.Stdout, logger, func(callCtx context.Context, oid, cachePath string) error { - return drsremote.DownloadToCachePath(callCtx, drsCtx, logger, oid, cachePath) - }) + var downloadFn internalfilter.SmudgeDownloadFunc + if !internalfilter.ShouldSkipSmudge() { + downloadFn = func(callCtx context.Context, oid, cachePath string) error { + return internaltransfer.DownloadToCachePath(callCtx, drsCtx, oid, cachePath) + } + } + + return internalfilter.SmudgeContent(ctx, pathname, os.Stdin, os.Stdout, logger, downloadFn) } func init() {} diff --git a/cmd/track/main.go b/cmd/track/main.go index d0721d79..868b3352 100644 --- a/cmd/track/main.go +++ b/cmd/track/main.go @@ -4,13 +4,13 @@ import ( "context" "fmt" - "github.com/calypr/git-drs/internal/drstrack" + "github.com/calypr/git-drs/internal/gitrepo" "github.com/spf13/cobra" ) var ( - gitLFSTrackPatterns = drstrack.TrackPatterns - gitLFSListPatterns = drstrack.ListTrackedPatterns + gitLFSTrackPatterns = gitrepo.TrackPatterns + gitLFSListPatterns = gitrepo.ListTrackedPatterns ) var Cmd = NewCommand() @@ -56,7 +56,9 @@ func runTrack(cmd *cobra.Command, args []string) error { } if out != "" { - _, _ = fmt.Fprint(cmd.OutOrStdout(), out) + if _, err := fmt.Fprint(cmd.OutOrStdout(), out); err != nil { + return fmt.Errorf("write track output: %w", err) + } } return nil } diff --git a/cmd/untrack/main.go b/cmd/untrack/main.go index 5e498f51..fa22f4e5 100644 --- a/cmd/untrack/main.go +++ b/cmd/untrack/main.go @@ -4,11 +4,11 @@ import ( "context" "fmt" - "github.com/calypr/git-drs/internal/drstrack" + "github.com/calypr/git-drs/internal/gitrepo" "github.com/spf13/cobra" ) -var gitLFSUntrackPatterns = drstrack.UntrackPatterns +var gitLFSUntrackPatterns = gitrepo.UntrackPatterns var Cmd = NewCommand() @@ -48,7 +48,9 @@ func runUntrack(cmd *cobra.Command, args []string) error { } if out != "" { - _, _ = fmt.Fprint(cmd.OutOrStdout(), out) + if _, err := fmt.Fprint(cmd.OutOrStdout(), out); err != nil { + return fmt.Errorf("write untrack output: %w", err) + } } return nil } diff --git a/cmd/upload/main.go b/cmd/upload/main.go deleted file mode 100644 index 992b8f52..00000000 --- a/cmd/upload/main.go +++ /dev/null @@ -1,99 +0,0 @@ -package upload - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsobject" - "github.com/calypr/git-drs/internal/drsremote" - syupload "github.com/calypr/syfon/client/transfer/upload" - "github.com/spf13/cobra" -) - -var remote string - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "upload ", - Short: "Upload a file to a DRS server", - Long: "Upload a file to a DRS server, without creating an LFS pointer", - Args: cobra.MinimumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - - logger := drslog.GetLogger() - - config, err := config.LoadConfig() - if err != nil { - return err - } - - remoteName, err := config.GetRemoteOrDefault(remote) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - - client, err := config.GetRemoteClient(remoteName, logger) - if err != nil { - return err - } - - remoteConfig := config.GetRemote(remoteName) - organization := "" - project := "" - storagePrefix := "" - bucketName := "" - if remoteConfig != nil { - organization = remoteConfig.GetOrganization() - project = remoteConfig.GetProjectId() - storagePrefix = remoteConfig.GetStoragePrefix() - bucketName = remoteConfig.GetBucketName() - } - - for _, src := range args { - if s, err := os.Stat(src); err != nil { - logger.Error(fmt.Sprintf("Error stating file %s: %v", src, err)) - return err - } else if s.IsDir() { - logger.Error(fmt.Sprintf("Skipping directory %s", src)) - continue - } else { - sha256, err := common.CalculateFileSHA256(src) - if err != nil { - logger.Error(fmt.Sprintf("Error calculating SHA256 for file %s: %v", src, err)) - return err - } - - objs, err := drsremote.ObjectsByHashForScope(cmd.Context(), client, sha256) - if err != nil || len(objs) == 0 { - did := sha256 - name := filepath.Base(src) - drsObj, err := drsobject.BuildWithPrefix(name, sha256, s.Size(), did, bucketName, organization, project, storagePrefix) - if err != nil { - return fmt.Errorf("build DRS object for %s: %w", src, err) - } - registered, err := syupload.RegisterFile(cmd.Context(), client.Client.Data(), client.Client.DRS(), drsObj, src, bucketName) - if err != nil { - return fmt.Errorf("error uploading %s: %v", src, err) - } - if registered != nil { - logger.Info(fmt.Sprintf("Successfully uploaded %s to server with DRS ID %s", src, registered.Id)) - } - } else { - logger.Info(fmt.Sprintf("File %s already exists on server with DRS ID %s, skipping upload", src, strings.TrimSpace(objs[0].Id))) - } - } - } - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") -} diff --git a/coverage/combined.html b/coverage/combined.html index f00491d3..bcd8b776 100644 --- a/coverage/combined.html +++ b/coverage/combined.html @@ -61,109 +61,135 @@ - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -526,7 +552,7 @@ "os" "path/filepath" - "github.com/calypr/git-drs/internal/cloud" + sycloud "github.com/calypr/syfon/client/cloud" "github.com/spf13/cobra" ) @@ -577,7 +603,7 @@ // printResolvedInfo writes a human-readable summary of resolved Git/LFS and // cloud object information to the command's stdout for user confirmation. -func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, objectInfo *cloud.ObjectInfo, pathArg string, isTracked bool, sha256 string) error { +func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, objectInfo *sycloud.ObjectInfo, pathArg string, isTracked bool, sha256 string) error { if _, err := fmt.Fprintf(cmd.OutOrStdout(), ` Resolved Git LFS Object Info ---------------------------- @@ -651,29 +677,34 @@ // NewCommand constructs the Cobra command for the `add-url` subcommand, // wiring usage, argument validation and the RunE handler. -func NewCommand() *cobra.Command { +func NewCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "add-url <cloud-url> [path]", - Short: "Add a file to the Git DRS repo using a cloud object URL", + Use: "add-url <object-url-or-key> [path]", + Short: "Add a file from a provider URL or configured bucket object key", Args: func(cmd *cobra.Command, args []string) error { if len(args) < 1 || len(args) > 2 { - return errors.New("usage: add-url <cloud-url> [path]") + return errors.New("usage: add-url <object-url-or-key> [path]") } return nil }, RunE: runAddURL, } - addFlags(cmd) + addFlags(cmd) return cmd } // addFlags registers optional expected SHA256 checksum. -func addFlags(cmd *cobra.Command) { +func addFlags(cmd *cobra.Command) { cmd.Flags().String( "sha256", "", "Expected SHA256 checksum (optional)", ) + cmd.Flags().String( + "scheme", + "", + "Storage scheme for object-key mode (for example: s3 or gs)", + ) } // runAddURL is the Cobra RunE wrapper that delegates execution to the service. @@ -688,72 +719,132 @@ "fmt" "net/url" "os" + "path" "strings" - "github.com/calypr/git-drs/internal/cloud" + "github.com/calypr/git-drs/internal/gitrepo" + sycloud "github.com/calypr/syfon/client/cloud" "github.com/spf13/cobra" ) // addURLInput holds the parsed CLI state for the add-url command. type addURLInput struct { - objectURL string - path string - sha256 string - objectParams cloud.ObjectParameters + sourceArg string + objectURL string + path string + sha256 string + scheme string } -// parseAddURLInput parses CLI args and flags into an addURLInput and constructs -// cloud.ObjectParameters for metadata inspection. -func parseAddURLInput(cmd *cobra.Command, args []string) (addURLInput, error) { - objectURL := args[0] +// parseAddURLInput parses CLI args and flags into an addURLInput. +func parseAddURLInput(cmd *cobra.Command, args []string) (addURLInput, error) { + sourceArg := strings.TrimSpace(args[0]) - pathArg, err := resolvePathArg(objectURL, args) + pathArg, err := resolvePathArg(sourceArg, args) if err != nil { return addURLInput{}, err } - sha256Param, err := cmd.Flags().GetString("sha256") + sha256Param, err := cmd.Flags().GetString("sha256") if err != nil { return addURLInput{}, fmt.Errorf("read flag sha256: %w", err) } + scheme, err := cmd.Flags().GetString("scheme") + if err != nil { + return addURLInput{}, fmt.Errorf("read flag scheme: %w", err) + } - return addURLInput{ - objectURL: objectURL, + return addURLInput{ + sourceArg: sourceArg, path: pathArg, sha256: sha256Param, - objectParams: cloud.ObjectParameters{ - ObjectURL: objectURL, - S3Region: firstNonEmpty(os.Getenv("AWS_REGION"), os.Getenv("AWS_DEFAULT_REGION"), os.Getenv("TEST_BUCKET_REGION")), - S3Endpoint: firstNonEmpty(os.Getenv("AWS_ENDPOINT_URL_S3"), os.Getenv("AWS_ENDPOINT_URL"), os.Getenv("TEST_BUCKET_ENDPOINT")), - S3AccessKey: firstNonEmpty(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("TEST_BUCKET_ACCESS_KEY")), - S3SecretKey: firstNonEmpty(os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("TEST_BUCKET_SECRET_KEY")), - SHA256: sha256Param, - DestinationPath: pathArg, - }, + scheme: strings.ToLower(strings.TrimSpace(scheme)), }, nil } // resolvePathArg returns the explicit destination path argument when provided, -// otherwise derives the worktree path from the given cloud URL path component. -func resolvePathArg(objectURL string, args []string) (string, error) { +// otherwise derives the worktree path from the given cloud URL or object key. +func resolvePathArg(sourceArg string, args []string) (string, error) { if len(args) == 2 { return args[1], nil } - u, err := url.Parse(objectURL) + if looksLikeCloudURL(sourceArg) { + u, err := url.Parse(sourceArg) + if err != nil { + return "", err + } + return strings.TrimPrefix(u.Path, "/"), nil + } + return strings.Trim(strings.TrimSpace(sourceArg), "/"), nil +} + +func buildObjectParameters(objectURL, pathArg, sha256 string) sycloud.ObjectParameters { + return sycloud.ObjectParameters{ + ObjectURL: objectURL, + S3Region: firstNonEmpty(os.Getenv("AWS_REGION"), os.Getenv("AWS_DEFAULT_REGION"), os.Getenv("TEST_BUCKET_REGION")), + S3Endpoint: firstNonEmpty(os.Getenv("AWS_ENDPOINT_URL_S3"), os.Getenv("AWS_ENDPOINT_URL"), os.Getenv("TEST_BUCKET_ENDPOINT")), + S3AccessKey: firstNonEmpty(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("TEST_BUCKET_ACCESS_KEY")), + S3SecretKey: firstNonEmpty(os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("TEST_BUCKET_SECRET_KEY")), + SHA256: sha256, + DestinationPath: pathArg, + } +} + +func looksLikeCloudURL(raw string) bool { + u, err := url.Parse(strings.TrimSpace(raw)) if err != nil { - return "", err + return false + } + if strings.TrimSpace(u.Scheme) == "" { + return false + } + switch strings.ToLower(strings.TrimSpace(u.Scheme)) { + case "s3", "gs", "gcs", "azblob", "http", "https": + return strings.TrimSpace(u.Host) != "" + default: + return false + } +} + +func resolveObjectURL(input addURLInput, scope gitrepo.ResolvedBucketScope) (string, error) { + if looksLikeCloudURL(input.sourceArg) { + return input.sourceArg, nil } - return strings.TrimPrefix(u.Path, "/"), nil + if input.scheme == "" { + return "", fmt.Errorf("object key mode requires --scheme because local bucket mappings store bucket/prefix but not provider scheme") + } + key := joinObjectKey(scope.Prefix, input.sourceArg) + switch input.scheme { + case "s3": + return fmt.Sprintf("s3://%s/%s", scope.Bucket, key), nil + case "gs", "gcs": + return fmt.Sprintf("gs://%s/%s", scope.Bucket, key), nil + case "azblob", "az": + return "", fmt.Errorf("object key mode for Azure requires a full azblob:// URL because the local mapping does not store account_name") + default: + return "", fmt.Errorf("unsupported --scheme %q (expected s3 or gs, or pass a full object URL)", input.scheme) + } +} + +func joinObjectKey(prefix, key string) string { + parts := make([]string, 0, 2) + if p := strings.Trim(strings.TrimSpace(prefix), "/"); p != "" { + parts = append(parts, p) + } + if k := strings.Trim(strings.TrimSpace(key), "/"); k != "" { + parts = append(parts, k) + } + return path.Join(parts...) } -func firstNonEmpty(values ...string) string { - for _, v := range values { +func firstNonEmpty(values ...string) string { + for _, v := range values { v = strings.TrimSpace(v) - if v != "" { + if v != "" { return v } } - return "" + return "" } @@ -790,16 +881,20 @@ import ( "context" + "crypto/sha256" "fmt" "log/slog" - "os" + "strings" - "github.com/calypr/git-drs/internal/cloud" "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsmap" + "github.com/calypr/git-drs/internal/drsobject" + "github.com/calypr/git-drs/internal/drstrack" "github.com/calypr/git-drs/internal/lfs" + drsapi "github.com/calypr/syfon/apigen/client/drs" + sycloud "github.com/calypr/syfon/client/cloud" + "github.com/google/uuid" "github.com/spf13/cobra" ) @@ -807,7 +902,7 @@ // behavior (logger factory, object inspection, LFS helpers, config loader, etc.). type AddURLService struct { newLogger func(string, bool) (*slog.Logger, error) - inspectObject func(ctx context.Context, input cloud.ObjectParameters) (*cloud.ObjectInfo, error) + inspectObject func(ctx context.Context, input sycloud.ObjectParameters) (*sycloud.ObjectInfo, error) isLFSTracked func(path string) (bool, error) getGitRoots func(ctx context.Context) (string, string, error) gitLFSTrack func(ctx context.Context, path string) (bool, error) @@ -816,131 +911,186 @@ // NewAddURLService constructs an AddURLService populated with production // implementations of its dependencies. -func NewAddURLService() *AddURLService { +func NewAddURLService() *AddURLService { return &AddURLService{ newLogger: drslog.NewLogger, - inspectObject: cloud.InspectObjectForLFS, + inspectObject: sycloud.InspectObject, isLFSTracked: lfs.IsLFSTracked, getGitRoots: lfs.GetGitRootDirectories, - gitLFSTrack: lfs.GitLFSTrackReadOnly, + gitLFSTrack: drstrack.TrackReadOnly, loadConfig: config.LoadConfig, } } -// Run executes the add-url workflow: parse CLI input, inspect the cloud object, +// Run executes the add-url workflow: parse CLI input, resolve the target bucket +// scope, inspect the provider object through the client-owned cloud package, // ensure the LFS object exists in local storage, write a pointer file, update // the pre-commit cache (best-effort), optionally add a tracking entry, and // record the DRS mapping. -func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { +func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - if ctx == nil { + if ctx == nil { ctx = context.Background() } - logger, err := s.newLogger("", false) + logger, err := s.newLogger("", false) if err != nil { return fmt.Errorf("error creating logger: %v", err) } - input, err := parseAddURLInput(cmd, args) + input, err := parseAddURLInput(cmd, args) if err != nil { return err } - objectInfo, err := s.inspectObject(ctx, input.objectParams) + cfg, err := s.loadConfig() if err != nil { - return err + return fmt.Errorf("error getting config: %v", err) } - isTracked, err := s.isLFSTracked(input.path) + remote, err := cfg.GetDefaultRemote() if err != nil { - return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) + return err } - gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) - if err != nil { - return fmt.Errorf("get git root directories: %w", err) + remoteConfig := cfg.GetRemote(remote) + if remoteConfig == nil { + return fmt.Errorf("error getting remote configuration for %s", remote) } - if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, objectInfo, input.path, isTracked, input.sha256); err != nil { + org, project, scope, err := resolveTargetScope(remoteConfig) + if err != nil { return err } - oid, err := s.ensureLFSObject(ctx, objectInfo, input, lfsRoot) + input.objectURL, err = resolveObjectURL(input, scope) if err != nil { return err } - if err := writePointerFile(input.path, oid, objectInfo.SizeBytes); err != nil { + objectInfo, err := s.inspectObject(ctx, buildObjectParameters(input.objectURL, input.path, input.sha256)) + if err != nil { return err } - if err := updatePrecommitCache(ctx, logger, input.path, oid, input.objectURL); err != nil { - logger.Warn("pre-commit cache update skipped", "error", err) + isTracked, err := s.isLFSTracked(input.path) + if err != nil { + return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) + } + + gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) + if err != nil { + return fmt.Errorf("get git root directories: %w", err) } - if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { + if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, objectInfo, input.path, isTracked, input.sha256); err != nil { return err } - cfg, err := s.loadConfig() + oid, err := s.ensureLFSObject(ctx, objectInfo, input, lfsRoot) if err != nil { - return fmt.Errorf("error getting config: %v", err) + return err } - remote, err := cfg.GetDefaultRemote() - if err != nil { + if err := writePointerFile(input.path, oid, objectInfo.SizeBytes); err != nil { return err } - remoteConfig := cfg.GetRemote(remote) - if remoteConfig == nil { - return fmt.Errorf("error getting remote configuration for %s", remote) + if err := updatePrecommitCache(ctx, logger, input.path, oid, input.objectURL); err != nil { + logger.Warn("pre-commit cache update skipped", "error", err) } - org, project, scope, err := resolveTargetScope(remoteConfig) - if err != nil { + if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { return err } - builder := common.NewObjectBuilder(scope.Bucket, project) + builder := drsobject.NewBuilder(scope.Bucket, project) builder.Organization = org builder.StoragePrefix = scope.Prefix - file := lfs.LfsFileInfo{ + file := addURLDrsFile{ Name: input.path, Size: objectInfo.SizeBytes, Oid: oid, } - if _, err := drsmap.WriteDrsFile(builder, file, &input.objectURL); err != nil { - return fmt.Errorf("error WriteDrsFile: %v", err) + if _, err := writeAddURLDrsObject(builder, file, input.objectURL); err != nil { + return fmt.Errorf("write local DRS object: %w", err) } - return nil + return nil +} + +type addURLDrsFile struct { + Name string + Size int64 + Oid string +} + +func writeAddURLDrsObject(builder drsobject.Builder, file addURLDrsFile, objectPath string) (*drsapi.DrsObject, error) { + existing, err := drsobject.ReadObject(common.DRS_OBJS_PATH, file.Oid) + var drsObj *drsapi.DrsObject + if err == nil && existing != nil { + drsObj = existing + name := file.Name + drsObj.Name = &name + drsObj.Size = file.Size + } else { + drsID := uuid.NewSHA1(drsobject.UUIDNamespace, []byte(fmt.Sprintf("%s:%s", builder.Project, drsobject.NormalizeOid(file.Oid)))).String() + drsObj, err = builder.Build(file.Name, file.Oid, file.Size, drsID) + if err != nil { + return nil, fmt.Errorf("error building DRS object for oid %s: %w", file.Oid, err) + } + } + + if objectPath != "" { + if drsObj.AccessMethods != nil && len(*drsObj.AccessMethods) > 0 { + am := &(*drsObj.AccessMethods)[0] + am.AccessUrl = &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: objectPath} + } else { + drsObj.AccessMethods = &[]drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: objectPath}, + }} + } + } + + if err := drsobject.WriteObject(common.DRS_OBJS_PATH, drsObj, file.Oid); err != nil { + return nil, fmt.Errorf("error writing DRS object for oid %s: %w", file.Oid, err) + } + return drsObj, nil } -// ensureLFSObject ensures the LFS object identified by objectInfo exists in the -// repository's LFS storage. If SHA256 is provided, it is trusted and returned. -// Otherwise we create a sentinel object and synthetic OID derived from ETag, -// deferring true checksum validation to first real data use. -func (s *AddURLService) ensureLFSObject(ctx context.Context, objectInfo *cloud.ObjectInfo, input addURLInput, lfsRoot string) (string, error) { +// ensureLFSObject returns the LFS pointer OID to use for the add-url target. +// If SHA256 is provided, it is trusted and returned. Otherwise we derive a +// deterministic placeholder OID from provider identity without writing any +// local LFS object payload. +func (s *AddURLService) ensureLFSObject(ctx context.Context, objectInfo *sycloud.ObjectInfo, input addURLInput, lfsRoot string) (string, error) { _ = ctx + _ = lfsRoot if input.sha256 != "" { return input.sha256, nil } - oid, err := lfs.SyntheticOIDFromETag(objectInfo.ETag) - if err != nil { - return "", err - } - objPath, err := lfs.WriteAddURLSentinelObject(lfsRoot, oid, objectInfo.ETag, input.objectURL) - if err != nil { - return "", err + return placeholderOIDForUnknownSHA(objectInfo.ETag, input.objectURL) +} + +func placeholderOIDForUnknownSHA(etag string, sourceURL string) (string, error) { + e := strings.TrimSpace(strings.Trim(etag, `"`)) + src := strings.TrimSpace(sourceURL) + if e == "" { + return "", fmt.Errorf("etag is required for placeholder oid") } - if _, err := fmt.Fprintf(os.Stderr, "Added add-url sentinel object at %s\n", objPath); err != nil { - return "", fmt.Errorf("stderr write: %w", err) + if src == "" { + return "", fmt.Errorf("source URL is required for placeholder oid") } - return oid, nil + sum := sha256.Sum256([]byte("git-drs-add-url-placeholder:v2\netag=" + e + "\nsource=" + src + "\n")) + return fmt.Sprintf("%x", sum[:]), nil } @@ -957,11 +1107,11 @@ "strings" "time" - gitauth "github.com/calypr/git-drs/internal/auth" + "github.com/calypr/data-client/credentials" "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" - "github.com/calypr/syfon/client/conf" + conf "github.com/calypr/syfon/client/config" "github.com/spf13/cobra" ) @@ -1195,7 +1345,7 @@ if prof, err := configure.Load(remoteName); err == nil { token = strings.TrimSpace(prof.AccessToken) if token == "" { - if ensureErr := gitauth.EnsureValidCredential(context.Background(), prof, drslog.GetLogger()); ensureErr == nil { + if ensureErr := credentials.EnsureValidCredential(context.Background(), prof, drslog.GetLogger()); ensureErr == nil { _ = configure.Save(prof) token = strings.TrimSpace(prof.AccessToken) } @@ -1326,135 +1476,487 @@ } -