diff --git a/pkg/cmd/dev/dev.go b/pkg/cmd/dev/dev.go index d54ec0ad8..3c02fa14b 100644 --- a/pkg/cmd/dev/dev.go +++ b/pkg/cmd/dev/dev.go @@ -30,6 +30,7 @@ import ( "github.com/kitops-ml/kitops/pkg/lib/constants" "github.com/kitops-ml/kitops/pkg/lib/filesystem" "github.com/kitops-ml/kitops/pkg/lib/filesystem/unpack" + "github.com/kitops-ml/kitops/pkg/lib/filter" "github.com/kitops-ml/kitops/pkg/lib/harness" kfutils "github.com/kitops-ml/kitops/pkg/lib/kitfile" "github.com/kitops-ml/kitops/pkg/lib/repo/util" @@ -192,11 +193,11 @@ func extractModelKitToCache(ctx context.Context, options *DevStartOptions) error } // Add model filter - modelFilter, err := unpack.ParseFilter("model,kitfile") + modelFilter, err := filter.ParseFilter("model,kitfile") if err != nil { return fmt.Errorf("failed to create model filter: %w", err) } - libOpts.FilterConfs = []unpack.FilterConf{*modelFilter} + libOpts.FilterConfs = []filter.FilterConf{*modelFilter} err = unpack.UnpackModelKit(ctx, libOpts) if err != nil { diff --git a/pkg/cmd/list/cmd.go b/pkg/cmd/list/cmd.go index 0842f1d2c..d3e940931 100644 --- a/pkg/cmd/list/cmd.go +++ b/pkg/cmd/list/cmd.go @@ -27,6 +27,7 @@ import ( "github.com/kitops-ml/kitops/pkg/cmd/options" "github.com/kitops-ml/kitops/pkg/lib/constants" + "github.com/kitops-ml/kitops/pkg/lib/filter" "github.com/kitops-ml/kitops/pkg/lib/repo/util" "github.com/kitops-ml/kitops/pkg/output" @@ -76,10 +77,12 @@ kit list registry.example.com/my-namespace/my-model` type listOptions struct { options.NetworkOptions - configHome string - remoteRef *registry.Reference - format string - template string + configHome string + remoteRef *registry.Reference + format string + template string + filters []string + filterConfs []filter.FilterConf } func (opts *listOptions) complete(ctx context.Context, args []string) error { @@ -112,6 +115,13 @@ func (opts *listOptions) complete(ctx context.Context, args []string) error { opts.template = opts.format opts.format = "template" } + for _, fStr := range opts.filters { + fConf, err := filter.ParseFilter(fStr) + if err != nil { + return fmt.Errorf("invalid filter syntax '%s': %w", fStr, err) + } + opts.filterConfs = append(opts.filterConfs, *fConf) + } printConfig(opts) return nil @@ -131,6 +141,7 @@ func ListCommand() *cobra.Command { cmd.Args = cobra.MaximumNArgs(1) cmd.Flags().StringVar(&opts.format, "format", "table", "Output format: table, json, or Go template string") + cmd.Flags().StringArrayVarP(&opts.filters, "filter", "f", []string{}, "Filter what is listed based on type and name. Can be specified multiple times") opts.AddNetworkFlags(cmd) cmd.Flags().SortFlags = false diff --git a/pkg/cmd/list/list.go b/pkg/cmd/list/list.go index 83f137c71..d31b4bd15 100644 --- a/pkg/cmd/list/list.go +++ b/pkg/cmd/list/list.go @@ -22,6 +22,7 @@ import ( "sort" "github.com/kitops-ml/kitops/pkg/lib/constants" + "github.com/kitops-ml/kitops/pkg/lib/filter" "github.com/kitops-ml/kitops/pkg/lib/repo/local" "github.com/kitops-ml/kitops/pkg/lib/repo/util" ) @@ -35,7 +36,7 @@ func listLocalKits(ctx context.Context, opts *listOptions) ([]modelInfo, error) } var allInfo []modelInfo for _, repo := range localRepos { - infos, err := readInfoFromRepo(ctx, repo) + infos, err := readInfoFromRepo(ctx, repo, opts) if err != nil { return nil, err } @@ -45,24 +46,25 @@ func listLocalKits(ctx context.Context, opts *listOptions) ([]modelInfo, error) return allInfo, nil } -func readInfoFromRepo(ctx context.Context, repo local.LocalRepo) ([]modelInfo, error) { +func readInfoFromRepo(ctx context.Context, repo local.LocalRepo, opts *listOptions) ([]modelInfo, error) { var infos []modelInfo manifestDescs := repo.GetAllModels() for _, manifestDesc := range manifestDescs { manifest, config, err := util.GetManifestAndKitfile(ctx, repo, manifestDesc) if err != nil { if errors.Is(err, util.ErrNotAModelKit) { - // Shouldn't happen since this is a local repo, but either way it's not a supported artifact continue } - // Allow artifacts without Kitfiles as all that will be lacking is some metadata; we can still - // describe them if !errors.Is(err, util.ErrNoKitfile) { return nil, err } } + + if !filter.KitfileMatches(config, opts.filterConfs) { + continue + } + tags := repo.GetTags(manifestDesc) - // Strip localhost from repo if present, since we added it repository := util.FormatRepositoryForDisplay(repo.GetRepoName()) if repository == "" { repository = "" diff --git a/pkg/cmd/list/remote.go b/pkg/cmd/list/remote.go index 602918878..8f0230bce 100644 --- a/pkg/cmd/list/remote.go +++ b/pkg/cmd/list/remote.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/kitops-ml/kitops/pkg/lib/constants/mediatype" + "github.com/kitops-ml/kitops/pkg/lib/filter" "github.com/kitops-ml/kitops/pkg/lib/repo/remote" "github.com/kitops-ml/kitops/pkg/lib/repo/util" @@ -34,16 +35,16 @@ func listRemoteKits(ctx context.Context, opts *listOptions) ([]modelInfo, error) return nil, fmt.Errorf("failed to read repository: %w", err) } if opts.remoteRef.Reference != "" { - info, err := listImageTag(ctx, repo, opts.remoteRef) + info, err := listImageTag(ctx, repo, opts.remoteRef, opts) if info == nil || err != nil { return nil, err } return []modelInfo{*info}, nil } - return listTags(ctx, repo, opts.remoteRef) + return listTags(ctx, repo, opts.remoteRef, opts) } -func listTags(ctx context.Context, repo registry.Repository, ref *registry.Reference) ([]modelInfo, error) { +func listTags(ctx context.Context, repo registry.Repository, ref *registry.Reference, opts *listOptions) ([]modelInfo, error) { var tags []string err := repo.Tags(ctx, "", func(tagsPage []string) error { tags = append(tags, tagsPage...) @@ -60,7 +61,7 @@ func listTags(ctx context.Context, repo registry.Repository, ref *registry.Refer Repository: ref.Repository, Reference: tag, } - info, err := listImageTag(ctx, repo, tagRef) + info, err := listImageTag(ctx, repo, tagRef, opts) if err != nil && !errors.Is(err, util.ErrNotAModelKit) { return nil, err } @@ -72,7 +73,7 @@ func listTags(ctx context.Context, repo registry.Repository, ref *registry.Refer return allInfos, nil } -func listImageTag(ctx context.Context, repo registry.Repository, ref *registry.Reference) (*modelInfo, error) { +func listImageTag(ctx context.Context, repo registry.Repository, ref *registry.Reference, opts *listOptions) (*modelInfo, error) { manifestDesc, err := repo.Resolve(ctx, ref.Reference) if err != nil { return nil, fmt.Errorf("failed to resolve reference %s: %w", ref.Reference, err) @@ -84,6 +85,11 @@ func listImageTag(ctx context.Context, repo registry.Repository, ref *registry.R if _, err := mediatype.ModelFormatForManifest(manifest); err != nil { return nil, nil } + + if !filter.KitfileMatches(config, opts.filterConfs) { + return nil, nil + } + info := &modelInfo{ Repo: ref.Repository, Digest: string(manifestDesc.Digest), diff --git a/pkg/cmd/unpack/cmd.go b/pkg/cmd/unpack/cmd.go index fdec7923b..7600b7674 100644 --- a/pkg/cmd/unpack/cmd.go +++ b/pkg/cmd/unpack/cmd.go @@ -26,6 +26,7 @@ import ( "github.com/kitops-ml/kitops/pkg/lib/completion" "github.com/kitops-ml/kitops/pkg/lib/constants" "github.com/kitops-ml/kitops/pkg/lib/filesystem/unpack" + "github.com/kitops-ml/kitops/pkg/lib/filter" "github.com/kitops-ml/kitops/pkg/lib/repo/util" "github.com/kitops-ml/kitops/pkg/output" @@ -204,16 +205,16 @@ func runCommand(opts *unpackOptions) func(*cobra.Command, []string) error { // Handle deprecated flags by converting to filters conf := opts.unpackConf if conf.unpackKitfile || conf.unpackModels || conf.unpackCode || conf.unpackDatasets || conf.unpackDocs { - deprecatedFilters := unpack.FiltersFromUnpackConf( + deprecatedFilters := filter.FiltersFromUnpackConf( conf.unpackKitfile, conf.unpackModels, conf.unpackCode, conf.unpackDatasets, conf.unpackDocs) libOpts.FilterConfs = deprecatedFilters } else if len(opts.filters) > 0 { // Parse filters using library functionality - for _, filter := range opts.filters { - filterConf, err := unpack.ParseFilter(filter) + for _, fStr := range opts.filters { + filterConf, err := filter.ParseFilter(fStr) if err != nil { - return output.Fatalf("Invalid filter %q: %s", filter, err) + return output.Fatalf("Invalid filter %q: %s", fStr, err) } libOpts.FilterConfs = append(libOpts.FilterConfs, *filterConf) } diff --git a/pkg/lib/filesystem/unpack/core.go b/pkg/lib/filesystem/unpack/core.go index 267ee041c..16ef71826 100644 --- a/pkg/lib/filesystem/unpack/core.go +++ b/pkg/lib/filesystem/unpack/core.go @@ -32,6 +32,7 @@ import ( "github.com/kitops-ml/kitops/pkg/lib/constants" "github.com/kitops-ml/kitops/pkg/lib/constants/mediatype" "github.com/kitops-ml/kitops/pkg/lib/filesystem" + "github.com/kitops-ml/kitops/pkg/lib/filter" "github.com/kitops-ml/kitops/pkg/lib/repo/util" "github.com/kitops-ml/kitops/pkg/output" @@ -107,7 +108,7 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str return err } } - if shouldUnpackLayer(config, opts.FilterConfs) { + if filter.LayerMatches(config, opts.FilterConfs) { if err := unpackConfig(config, opts.UnpackDir, opts.Overwrite); err != nil { return err } @@ -138,7 +139,7 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str switch mediaType.Base() { case mediatype.ModelBaseType: entry := config.Model - if !shouldUnpackLayer(entry, opts.FilterConfs) { + if !filter.LayerMatches(entry, opts.FilterConfs) { continue } layerInfo, layerPath = entry.LayerInfo, entry.Path @@ -147,7 +148,7 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str case mediatype.ModelPartBaseType: entry := config.Model.Parts[modelPartIdx] modelPartIdx += 1 - if !shouldUnpackLayer(entry, opts.FilterConfs) { + if !filter.LayerMatches(entry, opts.FilterConfs) { continue } layerInfo, layerPath = entry.LayerInfo, entry.Path @@ -158,7 +159,7 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str if layerDesc.Annotations[constants.LayerSubtypeAnnotation] == constants.LayerSubtypePrompt { entry := config.Prompts[promptIdx] promptIdx += 1 - if !shouldUnpackLayer(entry, opts.FilterConfs) { + if !filter.LayerMatches(entry, opts.FilterConfs) { continue } layerInfo, layerPath = entry.LayerInfo, entry.Path @@ -166,7 +167,7 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str } else { entry := config.Code[codeIdx] codeIdx += 1 - if !shouldUnpackLayer(entry, opts.FilterConfs) { + if !filter.LayerMatches(entry, opts.FilterConfs) { continue } layerInfo, layerPath = entry.LayerInfo, entry.Path @@ -176,7 +177,7 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str case mediatype.DatasetBaseType: entry := config.DataSets[datasetIdx] datasetIdx += 1 - if !shouldUnpackLayer(entry, opts.FilterConfs) { + if !filter.LayerMatches(entry, opts.FilterConfs) { continue } layerInfo, layerPath = entry.LayerInfo, entry.Path @@ -185,7 +186,7 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str case mediatype.DocsBaseType: entry := config.Docs[docsIdx] docsIdx += 1 - if !shouldUnpackLayer(entry, opts.FilterConfs) { + if !filter.LayerMatches(entry, opts.FilterConfs) { continue } layerInfo, layerPath = entry.LayerInfo, entry.Path @@ -240,16 +241,16 @@ func unpackParent(ctx context.Context, ref string, optsIn *UnpackOptions, visite opts.ModelRef = parentRef // Unpack only model, ignore code/datasets if len(opts.FilterConfs) == 0 { - modelFilter, err := ParseFilter("model") + modelFilter, err := filter.ParseFilter("model") if err != nil { // Shouldn't happen, ever return fmt.Errorf("failed to parse filter for parent modelkit: %w", err) } - opts.FilterConfs = []FilterConf{*modelFilter} + opts.FilterConfs = []filter.FilterConf{*modelFilter} } else { - var filterConfs []FilterConf + var filterConfs []filter.FilterConf for _, conf := range opts.FilterConfs { - if conf.matchesBaseType("model") { + if slices.Contains(conf.BaseTypes, "model") { // Drop any other base types from this filter conf.BaseTypes = []string{"model"} filterConfs = append(filterConfs, conf) diff --git a/pkg/lib/filesystem/unpack/options.go b/pkg/lib/filesystem/unpack/options.go index f59e84209..d16b0ee31 100644 --- a/pkg/lib/filesystem/unpack/options.go +++ b/pkg/lib/filesystem/unpack/options.go @@ -18,6 +18,7 @@ package unpack import ( "github.com/kitops-ml/kitops/pkg/cmd/options" + "github.com/kitops-ml/kitops/pkg/lib/filter" "oras.land/oras-go/v2/registry" ) @@ -28,7 +29,7 @@ type UnpackOptions struct { ConfigHome string UnpackDir string Filters []string - FilterConfs []FilterConf + FilterConfs []filter.FilterConf ModelRef *registry.Reference Overwrite bool IgnoreExisting bool diff --git a/pkg/lib/filesystem/unpack/filter.go b/pkg/lib/filter/filter.go similarity index 73% rename from pkg/lib/filesystem/unpack/filter.go rename to pkg/lib/filter/filter.go index ffae8d9f0..3f2d19b7c 100644 --- a/pkg/lib/filesystem/unpack/filter.go +++ b/pkg/lib/filter/filter.go @@ -14,7 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -package unpack +package filter import ( "fmt" @@ -76,10 +76,10 @@ func ParseFilter(filter string) (*FilterConf, error) { return conf, nil } -// shouldUnpackLayer determines if we should unpack a layer in a Kitfile by matching +// LayerMatches determines if we should unpack a layer in a Kitfile by matching // fields against the filters. Matching is done against path and name (if present). // If filters is empty, we assume everything should be unpacked -func shouldUnpackLayer(layer any, filters []FilterConf) bool { +func LayerMatches(layer any, filters []FilterConf) bool { if len(filters) == 0 { return true } @@ -98,26 +98,26 @@ func shouldUnpackLayer(layer any, filters []FilterConf) bool { } return false case artifact.Model: - return matchesFilters("model", l.Name, filters) || matchesFilters("model", l.Path, filters) + return MatchesFilters("model", l.Name, filters) || MatchesFilters("model", l.Path, filters) case artifact.ModelPart: - return matchesFilters("model", l.Name, filters) || matchesFilters("model", l.Path, filters) + return MatchesFilters("model", l.Name, filters) || MatchesFilters("model", l.Path, filters) case artifact.Docs: // Docs does not have an ID/name field so we can only match on path - return matchesFilters("docs", l.Path, filters) + return MatchesFilters("docs", l.Path, filters) case artifact.DataSet: - return matchesFilters("datasets", l.Name, filters) || matchesFilters("datasets", l.Path, filters) + return MatchesFilters("datasets", l.Name, filters) || MatchesFilters("datasets", l.Path, filters) case artifact.Code: // Code does not have a ID/name field so we can only match on path - return matchesFilters("code", l.Path, filters) + return MatchesFilters("code", l.Path, filters) case artifact.Prompt: // Prompts do not have a ID/name field so we can only match on path - return matchesFilters("prompts", l.Path, filters) + return MatchesFilters("prompts", l.Path, filters) default: return false } } -func matchesFilters(baseType, field string, filterConfs []FilterConf) bool { +func MatchesFilters(baseType, field string, filterConfs []FilterConf) bool { for _, filterConf := range filterConfs { if filterConf.matches(baseType, field) { return true @@ -147,3 +147,48 @@ func FiltersFromUnpackConf(unpackKitfile, unpackModels, unpackCode, unpackDatase } return []FilterConf{filter} } + +// KitfileMatches checks if ANY layer within the Kitfile satisfies the provided filters. +// It is used by commands like 'kit list' to filter whole ModelKits. +func KitfileMatches(kf *artifact.KitFile, filters []FilterConf) bool { + // If no filters are provided, everything matches by default + if len(filters) == 0 { + return true + } + if kf == nil { + return false + } + if kf.Model != nil { + if LayerMatches(*kf.Model, filters) { + return true + } + for _, part := range kf.Model.Parts { + if LayerMatches(part, filters) { + return true + } + } + } + for _, dataset := range kf.DataSets { + if LayerMatches(dataset, filters) { + return true + } + } + for _, code := range kf.Code { + if LayerMatches(code, filters) { + return true + } + } + + for _, prompt := range kf.Prompts { + if LayerMatches(prompt, filters) { + return true + } + } + + for _, doc := range kf.Docs { + if LayerMatches(doc, filters) { + return true + } + } + return false +} diff --git a/pkg/lib/filesystem/unpack/filter_test.go b/pkg/lib/filter/filter_test.go similarity index 69% rename from pkg/lib/filesystem/unpack/filter_test.go rename to pkg/lib/filter/filter_test.go index bed310778..bb8396de3 100644 --- a/pkg/lib/filesystem/unpack/filter_test.go +++ b/pkg/lib/filter/filter_test.go @@ -14,11 +14,12 @@ // // SPDX-License-Identifier: Apache-2.0 -package unpack +package filter import ( "testing" + "github.com/kitops-ml/kitops/pkg/artifact" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -95,7 +96,6 @@ func TestParseFilter_EdgeCases(t *testing.T) { require.NoError(t, err) require.NotNil(t, result) - // Verify base types if tt.expectedTypes != nil { assert.Equal(t, tt.expectedTypes, result.BaseTypes) } @@ -104,7 +104,6 @@ func TestParseFilter_EdgeCases(t *testing.T) { if tt.expectedFilters != nil { assert.Equal(t, tt.expectedFilters, result.Filters) } else { - // If no specific filters expected, should be empty assert.Empty(t, result.Filters) } }) @@ -178,3 +177,79 @@ func TestFiltersFromUnpackConf(t *testing.T) { }) } } + +func TestKitfileMatches(t *testing.T) { + // Setup a mock Kitfile with a few different layers + kf := &artifact.KitFile{ + Model: &artifact.Model{Path: "models/my-model.bin"}, + Prompts: []artifact.Prompt{ + {Path: "prompts/pdf-processing.txt"}, + {Path: "prompts/greeting.txt"}, + }, + DataSets: []artifact.DataSet{ + {Path: "data/train.csv"}, + }, + } + + tests := []struct { + name string + filterStrings []string + expected bool + }{ + { + name: "No filters returns true (default list behavior)", + filterStrings: []string{}, + expected: true, + }, + { + name: "Matches existing model base type", + filterStrings: []string{"model"}, + expected: true, + }, + { + name: "Matches specific prompt path (AND logic)", + filterStrings: []string{"prompts:prompts/pdf-processing.txt"}, + expected: true, + }, + { + name: "Fails specific prompt path that doesn't exist", + filterStrings: []string{"prompts:non-existent.txt"}, + expected: false, + }, + { + name: "Matches OR logic across multiple flags", + filterStrings: []string{"code", "prompts"}, + expected: true, + }, + { + name: "Fails completely mismatched types", + filterStrings: []string{"code", "docs"}, + expected: false, + }, + { + name: "Fails gracefully on nil Kitfile", + filterStrings: []string{"model"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var filters []FilterConf + for _, fStr := range tt.filterStrings { + conf, err := ParseFilter(fStr) + assert.NoError(t, err) + filters = append(filters, *conf) + } + + if tt.name == "Fails gracefully on nil Kitfile" { + result := KitfileMatches(nil, filters) + assert.Equal(t, tt.expected, result) + return + } + + result := KitfileMatches(kf, filters) + assert.Equal(t, tt.expected, result) + }) + } +}