diff --git a/go/internal/feast/registry/grpc.go b/go/internal/feast/registry/grpc.go new file mode 100644 index 00000000000..14756f35bdc --- /dev/null +++ b/go/internal/feast/registry/grpc.go @@ -0,0 +1,147 @@ +package registry + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "strings" + + "github.com/feast-dev/feast/go/protos/feast/core" + registryPb "github.com/feast-dev/feast/go/protos/feast/registry" + "github.com/rs/zerolog/log" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/emptypb" +) + +type GrpcRegistryStore struct { + project string + client registryPb.RegistryServerClient + conn *grpc.ClientConn +} + +// NewGrpcRegistryStore creates a gRPC-backed registry store. +// +// TLS is enabled when any of the following are true (mirrors Python RemoteRegistryConfig): +// - config.Path uses the "grpcs://" scheme +// - config.IsTls is true +// - config.Cert is set (path to a PEM certificate file) +// +// Without TLS, the connection is insecure. config.Path may be bare "host:port", +// "grpc://host:port", or "grpcs://host:port". +func NewGrpcRegistryStore(config *RegistryConfig, project string) (*GrpcRegistryStore, error) { + target, schemeIsTLS := parseGrpcTarget(config.Path) + useTLS := schemeIsTLS || config.IsTls || config.Cert != "" + + var dialOpts []grpc.DialOption + if useTLS { + tlsCreds, err := buildTLSCredentials(config.Cert) + if err != nil { + return nil, err + } + dialOpts = append(dialOpts, grpc.WithTransportCredentials(tlsCreds)) + } else { + dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + conn, err := grpc.NewClient(target, dialOpts...) + if err != nil { + return nil, err + } + + log.Info().Msgf("Using gRPC Feature Registry: %s", config.Path) + return &GrpcRegistryStore{ + project: project, + client: registryPb.NewRegistryServerClient(conn), + conn: conn, + }, nil +} + +// buildTLSCredentials returns TLS transport credentials. If certPath is non-empty +// the PEM file is loaded as a custom root CA (for self-signed certs), otherwise +// the system certificate pool is used. +func buildTLSCredentials(certPath string) (credentials.TransportCredentials, error) { + if certPath == "" { + return credentials.NewTLS(&tls.Config{}), nil + } + pemBytes, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("grpc registry: reading cert file %q: %w", certPath, err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pemBytes) { + return nil, fmt.Errorf("grpc registry: no valid PEM certificates found in %q", certPath) + } + return credentials.NewTLS(&tls.Config{RootCAs: pool}), nil +} + +// parseGrpcTarget strips grpc:// or grpcs:// scheme prefixes and returns +// the target address along with whether TLS should be used. +func parseGrpcTarget(path string) (target string, useTLS bool) { + if strings.HasPrefix(path, "https://") { + return strings.TrimPrefix(path, "https://"), true + } + if strings.HasPrefix(path, "http://") { + return strings.TrimPrefix(path, "http://"), false + } + return path, false +} + +func (g *GrpcRegistryStore) getEntity(name string, allowCache bool) (*core.Entity, error) { + return g.client.GetEntity(context.Background(), ®istryPb.GetEntityRequest{ + Name: name, + Project: g.project, + AllowCache: allowCache, + }) +} + +func (g *GrpcRegistryStore) getFeatureView(name string, allowCache bool) (*core.FeatureView, error) { + return g.client.GetFeatureView(context.Background(), ®istryPb.GetFeatureViewRequest{ + Name: name, + Project: g.project, + AllowCache: allowCache, + }) +} + +func (g *GrpcRegistryStore) getSortedFeatureView(name string, allowCache bool) (*core.SortedFeatureView, error) { + return g.client.GetSortedFeatureView(context.Background(), ®istryPb.GetSortedFeatureViewRequest{ + Name: name, + Project: g.project, + AllowCache: allowCache, + }) +} + +func (g *GrpcRegistryStore) getOnDemandFeatureView(name string, allowCache bool) (*core.OnDemandFeatureView, error) { + return g.client.GetOnDemandFeatureView(context.Background(), ®istryPb.GetOnDemandFeatureViewRequest{ + Name: name, + Project: g.project, + AllowCache: allowCache, + }) +} + +func (g *GrpcRegistryStore) getFeatureService(name string, allowCache bool) (*core.FeatureService, error) { + return g.client.GetFeatureService(context.Background(), ®istryPb.GetFeatureServiceRequest{ + Name: name, + Project: g.project, + AllowCache: allowCache, + }) +} + +func (g *GrpcRegistryStore) GetRegistryProto() (*core.Registry, error) { + return g.client.Proto(context.Background(), &emptypb.Empty{}) +} + +func (g *GrpcRegistryStore) UpdateRegistryProto(rp *core.Registry) error { + return &NotImplementedError{FunctionName: "UpdateRegistryProto"} +} + +func (g *GrpcRegistryStore) Teardown() error { + return g.conn.Close() +} + +func (g *GrpcRegistryStore) HasFallback() bool { + return true +} diff --git a/go/internal/feast/registry/grpc_test.go b/go/internal/feast/registry/grpc_test.go new file mode 100644 index 00000000000..9c23bb4279e --- /dev/null +++ b/go/internal/feast/registry/grpc_test.go @@ -0,0 +1,191 @@ +//go:build !integration + +package registry + +import ( + "context" + "net" + "testing" + + "github.com/feast-dev/feast/go/protos/feast/core" + registryPb "github.com/feast-dev/feast/go/protos/feast/registry" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/types/known/emptypb" +) + +// testRegistryServer is a minimal RegistryServer implementation for testing. +type testRegistryServer struct { + registryPb.UnimplementedRegistryServerServer +} + +func (s *testRegistryServer) GetEntity(_ context.Context, req *registryPb.GetEntityRequest) (*core.Entity, error) { + return &core.Entity{Spec: &core.EntitySpecV2{Name: req.Name}}, nil +} + +func (s *testRegistryServer) GetFeatureView(_ context.Context, req *registryPb.GetFeatureViewRequest) (*core.FeatureView, error) { + return &core.FeatureView{Spec: &core.FeatureViewSpec{Name: req.Name}}, nil +} + +func (s *testRegistryServer) GetSortedFeatureView(_ context.Context, req *registryPb.GetSortedFeatureViewRequest) (*core.SortedFeatureView, error) { + return &core.SortedFeatureView{Spec: &core.SortedFeatureViewSpec{Name: req.Name}}, nil +} + +func (s *testRegistryServer) GetOnDemandFeatureView(_ context.Context, req *registryPb.GetOnDemandFeatureViewRequest) (*core.OnDemandFeatureView, error) { + return &core.OnDemandFeatureView{Spec: &core.OnDemandFeatureViewSpec{Name: req.Name}}, nil +} + +func (s *testRegistryServer) GetFeatureService(_ context.Context, req *registryPb.GetFeatureServiceRequest) (*core.FeatureService, error) { + return &core.FeatureService{Spec: &core.FeatureServiceSpec{Name: req.Name}}, nil +} + +func (s *testRegistryServer) Proto(_ context.Context, _ *emptypb.Empty) (*core.Registry, error) { + return &core.Registry{}, nil +} + +// newTestGrpcStore spins up an in-process gRPC server using bufconn and returns +// a GrpcRegistryStore backed by it, along with a cleanup function. +func newTestGrpcStore(t *testing.T, project string) (*GrpcRegistryStore, func()) { + t.Helper() + + listener := bufconn.Listen(1024 * 1024) + srv := grpc.NewServer() + registryPb.RegisterRegistryServerServer(srv, &testRegistryServer{}) + go func() { + if err := srv.Serve(listener); err != nil && err != grpc.ErrServerStopped { + t.Errorf("bufconn server error: %v", err) + } + }() + + conn, err := grpc.NewClient( + "passthrough:///bufconn", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return listener.DialContext(ctx) + }), + ) + if err != nil { + t.Fatalf("failed to create grpc client: %v", err) + } + + store := &GrpcRegistryStore{ + project: project, + client: registryPb.NewRegistryServerClient(conn), + conn: conn, + } + + cleanup := func() { + conn.Close() + srv.Stop() + listener.Close() + } + return store, cleanup +} + +func TestGrpcGetEntity(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + result, err := store.getEntity("test_entity", true) + + assert.Nil(t, err) + assert.Equal(t, "test_entity", result.Spec.Name) +} + +func TestGrpcGetFeatureView(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + result, err := store.getFeatureView("test_feature_view", true) + + assert.Nil(t, err) + assert.Equal(t, "test_feature_view", result.Spec.Name) +} + +func TestGrpcGetSortedFeatureView(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + result, err := store.getSortedFeatureView("test_sorted_view", true) + + assert.Nil(t, err) + assert.Equal(t, "test_sorted_view", result.Spec.Name) +} + +func TestGrpcGetOnDemandFeatureView(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + result, err := store.getOnDemandFeatureView("test_odfv", true) + + assert.Nil(t, err) + assert.Equal(t, "test_odfv", result.Spec.Name) +} + +func TestGrpcGetFeatureService(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + result, err := store.getFeatureService("test_feature_service", true) + + assert.Nil(t, err) + assert.Equal(t, "test_feature_service", result.Spec.Name) +} + +func TestGrpcGetRegistryProto(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + registry, err := store.GetRegistryProto() + + assert.Nil(t, err) + assert.IsType(t, &core.Registry{}, registry) +} + +func TestGrpcUpdateRegistryProtoNotImplemented(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + err := store.UpdateRegistryProto(&core.Registry{}) + + assert.NotNil(t, err) + assert.IsType(t, &NotImplementedError{}, err) + assert.Equal(t, "UpdateRegistryProto", err.(*NotImplementedError).FunctionName) +} + +func TestGrpcHasFallback(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + assert.True(t, store.HasFallback()) +} + +func TestGrpcTeardown(t *testing.T) { + store, cleanup := newTestGrpcStore(t, "test_project") + defer cleanup() + + err := store.Teardown() + + assert.Nil(t, err) + assert.Equal(t, connectivity.Shutdown, store.conn.GetState()) +} + +func TestParseGrpcTarget(t *testing.T) { + cases := []struct { + input string + target string + useTLS bool + }{ + {"http://host:50051", "host:50051", false}, + {"https://host:50051", "host:50051", true}, + {"host:50051", "host:50051", false}, + } + for _, c := range cases { + target, useTLS := parseGrpcTarget(c.input) + assert.Equal(t, c.target, target, "input: %s", c.input) + assert.Equal(t, c.useTLS, useTLS, "input: %s", c.input) + } +} diff --git a/go/internal/feast/registry/registry.go b/go/internal/feast/registry/registry.go index a5a23e85976..33b1f9b26f0 100644 --- a/go/internal/feast/registry/registry.go +++ b/go/internal/feast/registry/registry.go @@ -16,12 +16,13 @@ import ( var REGISTRY_SCHEMA_VERSION string = "1" var REGISTRY_STORE_CLASS_FOR_SCHEME map[string]string = map[string]string{ - "gs": "GCSRegistryStore", - "s3": "S3RegistryStore", - "file": "FileRegistryStore", - "http": "HttpRegistryStore", - "https": "HttpRegistryStore", - "": "FileRegistryStore", + "gs": "GCSRegistryStore", + "s3": "S3RegistryStore", + "file": "FileRegistryStore", + "http": "HttpRegistryStore", + "https": "HttpRegistryStore", + "remote": "GrpcRegistryStore", + "": "FileRegistryStore", } /* @@ -277,7 +278,7 @@ func (r *Registry) GetEntity(project string, entityName string) (*model.Entity, } func (r *Registry) GetEntityFromRegistry(entityName string, project string) (*model.Entity, error) { - entityProto, err := r.registryStore.(*HttpRegistryStore).getEntity(entityName, true) + entityProto, err := r.registryStore.(FallbackRegistryStore).getEntity(entityName, true) if err != nil { if errors.IsHTTPNotFoundError(err) { log.Error().Err(err).Msgf("no entity %s found in project %s", entityName, project) @@ -304,7 +305,7 @@ func (r *Registry) GetFeatureView(project string, featureViewName string) (*mode } func (r *Registry) GetFeatureViewFromRegistry(featureViewName string, project string) (*model.FeatureView, error) { - featureViewProto, err := r.registryStore.(*HttpRegistryStore).getFeatureView(featureViewName, true) + featureViewProto, err := r.registryStore.(FallbackRegistryStore).getFeatureView(featureViewName, true) if err != nil { if errors.IsHTTPNotFoundError(err) { log.Error().Err(err).Msgf("no feature view %s found in project %s", featureViewName, project) @@ -331,7 +332,7 @@ func (r *Registry) GetSortedFeatureView(project string, sortedFeatureViewName st } func (r *Registry) GetSortedFeatureViewFromRegistry(sortedFeatureViewName string, project string) (*model.SortedFeatureView, error) { - sortedFeatureViewProto, err := r.registryStore.(*HttpRegistryStore).getSortedFeatureView(sortedFeatureViewName, true) + sortedFeatureViewProto, err := r.registryStore.(FallbackRegistryStore).getSortedFeatureView(sortedFeatureViewName, true) if err != nil { if errors.IsHTTPNotFoundError(err) { log.Error().Err(err).Msgf("no sorted feature view %s found in project %s", sortedFeatureViewName, project) @@ -358,7 +359,7 @@ func (r *Registry) GetFeatureService(project string, featureServiceName string) } func (r *Registry) GetFeatureServiceFromRegistry(featureServiceName string, project string) (*model.FeatureService, error) { - featureServiceProto, err := r.registryStore.(*HttpRegistryStore).getFeatureService(featureServiceName, true) + featureServiceProto, err := r.registryStore.(FallbackRegistryStore).getFeatureService(featureServiceName, true) if err != nil { if errors.IsHTTPNotFoundError(err) { log.Error().Err(err).Msgf("no feature service %s found in project %s", featureServiceName, project) @@ -385,7 +386,7 @@ func (r *Registry) GetOnDemandFeatureView(project string, onDemandFeatureViewNam } func (r *Registry) GetOnDemandFeatureViewFromRegistry(onDemandFeatureViewName string, project string) (*model.OnDemandFeatureView, error) { - onDemandFeatureViewProto, err := r.registryStore.(*HttpRegistryStore).getOnDemandFeatureView(onDemandFeatureViewName, true) + onDemandFeatureViewProto, err := r.registryStore.(FallbackRegistryStore).getOnDemandFeatureView(onDemandFeatureViewName, true) if err != nil { if errors.IsHTTPNotFoundError(err) { log.Error().Err(err).Msgf("no on demand feature view %s found in project %s", onDemandFeatureViewName, project) @@ -418,6 +419,8 @@ func getRegistryStoreFromType(registryStoreType string, registryConfig *Registry return NewHttpRegistryStore(registryConfig, project) case "S3RegistryStore": return NewS3RegistryStore(registryConfig, repoPath), nil + case "GrpcRegistryStore": + return NewGrpcRegistryStore(registryConfig, project) } - return nil, errors.GrpcInternalErrorf("only FileRegistryStore or HttpRegistryStore as a RegistryStore is supported at this moment") + return nil, errors.GrpcInternalErrorf("only FileRegistryStore, HttpRegistryStore, S3RegistryStore, or GrpcRegistryStore is supported at this moment") } diff --git a/go/internal/feast/registry/registrystore.go b/go/internal/feast/registry/registrystore.go index c9c84508b7c..d575c377adf 100644 --- a/go/internal/feast/registry/registrystore.go +++ b/go/internal/feast/registry/registrystore.go @@ -11,3 +11,13 @@ type RegistryStore interface { Teardown() error HasFallback() bool } + +// FallbackRegistryStore is implemented by stores that support per-item fetching +// (HasFallback() == true). This avoids store-specific type casts in registry.go. +type FallbackRegistryStore interface { + getEntity(name string, allowCache bool) (*core.Entity, error) + getFeatureView(name string, allowCache bool) (*core.FeatureView, error) + getSortedFeatureView(name string, allowCache bool) (*core.SortedFeatureView, error) + getOnDemandFeatureView(name string, allowCache bool) (*core.OnDemandFeatureView, error) + getFeatureService(name string, allowCache bool) (*core.FeatureService, error) +} diff --git a/go/internal/feast/registry/repoconfig.go b/go/internal/feast/registry/repoconfig.go index 091c426d920..6caace4b571 100644 --- a/go/internal/feast/registry/repoconfig.go +++ b/go/internal/feast/registry/repoconfig.go @@ -43,6 +43,10 @@ type RegistryConfig struct { Path string `json:"path"` ClientId string `json:"client_id" default:"Unknown"` CacheTtlSeconds int64 `json:"cache_ttl_seconds" default:"600"` + // Cert is the path to a PEM certificate file used when connecting to a gRPC + // registry server over TLS. Mirrors RemoteRegistryConfig.cert in Python. + Cert string `json:"cert"` + IsTls bool `json:"is_tls"` } // NewRepoConfigFromJSON converts a JSON string into a RepoConfig struct and also sets the repo path. @@ -118,6 +122,12 @@ func (r *RepoConfig) GetRegistryConfig() (*RegistryConfig, error) { if value, ok := v.(string); ok { registryConfig.Path = value } + case "registry_type": + if value, ok := v.(string); ok { + if storeType, found := REGISTRY_STORE_CLASS_FOR_SCHEME[value]; found { + registryConfig.RegistryStoreType = storeType + } + } case "registry_store_type": if value, ok := v.(string); ok { registryConfig.RegistryStoreType = value @@ -126,6 +136,14 @@ func (r *RepoConfig) GetRegistryConfig() (*RegistryConfig, error) { if value, ok := v.(string); ok { registryConfig.ClientId = value } + case "cert": + if value, ok := v.(string); ok { + registryConfig.Cert = value + } + case "is_tls": + if value, ok := v.(bool); ok { + registryConfig.IsTls = value + } case "cache_ttl_seconds": // cache_ttl_seconds defaulted to type float64. Ex: "cache_ttl_seconds": 60 in registryConfigMap switch value := v.(type) { diff --git a/go/internal/feast/registry/repoconfig_test.go b/go/internal/feast/registry/repoconfig_test.go index 922438b61c3..2f311a70809 100644 --- a/go/internal/feast/registry/repoconfig_test.go +++ b/go/internal/feast/registry/repoconfig_test.go @@ -193,10 +193,10 @@ func TestGetRegistryConfig_Map(t *testing.T) { // Create a RepoConfig with a map Registry config := &RepoConfig{ Registry: map[string]interface{}{ - "path": "data/registry.db", - "registry_store_type": "local", - "client_id": "test_client_id", - "cache_ttl_seconds": 60, + "path": "data/registry.db", + "registry_type": "file", + "client_id": "test_client_id", + "cache_ttl_seconds": 60, }, } @@ -205,7 +205,7 @@ func TestGetRegistryConfig_Map(t *testing.T) { // Assert that the method correctly processed the map assert.Equal(t, "data/registry.db", registryConfig.Path) - assert.Equal(t, "local", registryConfig.RegistryStoreType) + assert.Equal(t, "FileRegistryStore", registryConfig.RegistryStoreType) assert.Equal(t, int64(60), registryConfig.CacheTtlSeconds) assert.Equal(t, "test_client_id", registryConfig.ClientId) } @@ -385,3 +385,180 @@ func TestNewRepoConfigForScyllaDBFromJSON(t *testing.T) { assert.Equal(t, int(85), int(config.OnlineStore["read_batch_size"].(float64))) assert.Equal(t, int(2), int(config.OnlineStore["table_name_format_version"].(float64))) } + +func TestGetRegistryConfig_RemoteRegistryType(t *testing.T) { + // registry_type: "remote" (Python RemoteRegistryConfig) should map to GrpcRegistryStore + config := &RepoConfig{ + Registry: map[string]interface{}{ + "registry_type": "remote", + "path": "registry-server:50051", + }, + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, "GrpcRegistryStore", registryConfig.RegistryStoreType) + assert.Equal(t, "registry-server:50051", registryConfig.Path) +} + +func TestGetRegistryConfig_CertAndIsTls(t *testing.T) { + // cert and is_tls fields should be parsed (mirrors Python RemoteRegistryConfig) + config := &RepoConfig{ + Registry: map[string]interface{}{ + "registry_type": "remote", + "path": "registry-server:50051", + "cert": "/path/to/server.crt", + "is_tls": true, + }, + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, "/path/to/server.crt", registryConfig.Cert) + assert.True(t, registryConfig.IsTls) +} + +func TestGetRegistryConfig_IsTlsFalseByDefault(t *testing.T) { + config := &RepoConfig{ + Registry: map[string]interface{}{ + "path": "registry-server:50051", + }, + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.False(t, registryConfig.IsTls) + assert.Empty(t, registryConfig.Cert) +} + +// HTTP registry config tests + +func TestGetRegistryConfig_HttpStringPath(t *testing.T) { + // registry: "http://..." as a plain string — path stored, store type left empty + // (scheme-based inference to HttpRegistryStore happens later in getRegistryStoreFromScheme) + config := &RepoConfig{ + Registry: "http://registry-server:8080", + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, "http://registry-server:8080", registryConfig.Path) + assert.Empty(t, registryConfig.RegistryStoreType) + assert.Equal(t, defaultClientID, registryConfig.ClientId) + assert.Equal(t, defaultCacheTtlSeconds, registryConfig.CacheTtlSeconds) +} + +func TestGetRegistryConfig_HttpsStringPath(t *testing.T) { + config := &RepoConfig{ + Registry: "https://registry-server:8443", + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, "https://registry-server:8443", registryConfig.Path) + assert.Empty(t, registryConfig.RegistryStoreType) +} + +func TestGetRegistryConfig_HttpExplicitStoreType(t *testing.T) { + config := &RepoConfig{ + Registry: map[string]interface{}{ + "registry_type": "http", + "path": "http://registry-server:8080", + }, + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, "HttpRegistryStore", registryConfig.RegistryStoreType) + assert.Equal(t, "http://registry-server:8080", registryConfig.Path) +} + +func TestGetRegistryConfig_HttpWithClientId(t *testing.T) { + config := &RepoConfig{ + Registry: map[string]interface{}{ + "registry_type": "https", + "path": "https://registry-server:8443", + "client_id": "my-service", + }, + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, "HttpRegistryStore", registryConfig.RegistryStoreType) + assert.Equal(t, "https://registry-server:8443", registryConfig.Path) + assert.Equal(t, "my-service", registryConfig.ClientId) +} + +func TestGetRegistryConfig_HttpWithAllFields(t *testing.T) { + config := &RepoConfig{ + Registry: map[string]interface{}{ + "registry_type": "https", + "path": "https://registry-server:8443", + "client_id": "feast-go-server", + "cache_ttl_seconds": float64(120), + }, + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, "HttpRegistryStore", registryConfig.RegistryStoreType) + assert.Equal(t, "https://registry-server:8443", registryConfig.Path) + assert.Equal(t, "feast-go-server", registryConfig.ClientId) + assert.Equal(t, int64(120), registryConfig.CacheTtlSeconds) +} + +func TestGetRegistryConfig_HttpDefaultClientId(t *testing.T) { + // client_id should default to "Unknown" when not specified + config := &RepoConfig{ + Registry: map[string]interface{}{ + "path": "http://registry-server:8080", + }, + } + + registryConfig, err := config.GetRegistryConfig() + + assert.Nil(t, err) + assert.Equal(t, defaultClientID, registryConfig.ClientId) +} + +func TestGetRegistryConfig_HttpFromYaml(t *testing.T) { + dir, err := os.MkdirTemp("", "feature_repo_*") + assert.Nil(t, err) + defer func() { + assert.Nil(t, os.RemoveAll(dir)) + }() + + filePath := filepath.Join(dir, "feature_store.yaml") + data := []byte(` +project: feature_repo +registry: + registry_type: https + path: "https://registry-server:8443" + client_id: "feast-go-server" + cache_ttl_seconds: 300 +provider: local +online_store: + type: redis + connection_string: "localhost:6379" +`) + err = os.WriteFile(filePath, data, 0666) + assert.Nil(t, err) + + config, err := NewRepoConfigFromFile(dir) + assert.Nil(t, err) + + registryConfig, err := config.GetRegistryConfig() + assert.Nil(t, err) + assert.Equal(t, "HttpRegistryStore", registryConfig.RegistryStoreType) + assert.Equal(t, "https://registry-server:8443", registryConfig.Path) + assert.Equal(t, "feast-go-server", registryConfig.ClientId) + assert.Equal(t, int64(300), registryConfig.CacheTtlSeconds) +}