From f863c2b50d2baa184fe0da7fdf9acafcae69ade1 Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Thu, 14 May 2026 17:23:45 -0700 Subject: [PATCH 1/8] refactor spanner client --- internal/server/spanner/client.go | 132 ++++--- internal/server/spanner/client_normalized.go | 38 -- internal/server/spanner/client_selector.go | 9 +- internal/server/spanner/executor.go | 352 +++++++++++++++++++ internal/server/spanner/query.go | 351 ++---------------- internal/server/spanner/query_normalized.go | 8 +- internal/server/spanner/timestamp_test.go | 4 +- test/setup.go | 13 +- 8 files changed, 462 insertions(+), 445 deletions(-) delete mode 100644 internal/server/spanner/client_normalized.go create mode 100644 internal/server/spanner/executor.go diff --git a/internal/server/spanner/client.go b/internal/server/spanner/client.go index e3ca09891..609933609 100644 --- a/internal/server/spanner/client.go +++ b/internal/server/spanner/client.go @@ -19,8 +19,6 @@ import ( "context" "fmt" "log/slog" - "sync" - "sync/atomic" "cloud.google.com/go/spanner" pb "github.com/datacommonsorg/mixer/internal/proto" @@ -58,35 +56,38 @@ type SpannerClient interface { Close() } -// spannerDatabaseClient encapsulates the Spanner client that directly interacts with the Spanner database. -type spannerDatabaseClient struct { - client *spanner.Client - timestamp atomic.Int64 - ticker Ticker - stopCh chan struct{} - startOnce sync.Once - stopOnce sync.Once - wg sync.WaitGroup - - // For mocking in tests. - updateTimestamp func(context.Context) error +// NormalizedObservationProvider defines the subset of methods supported by the normalized schema path. +type NormalizedObservationProvider interface { + GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) + CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) + GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) + GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) } -// newSpannerDatabaseClient creates a new spannerDatabaseClient. -func newSpannerDatabaseClient(client *spanner.Client) (*spannerDatabaseClient, error) { - sc := &spannerDatabaseClient{ - client: client, - } +// defaultSpannerClient encapsulates the Spanner client that directly interacts with the Spanner database. +type defaultSpannerClient struct { + exec *SpannerExecutor +} + +// newDefaultSpannerClient creates a new defaultSpannerClient. +func newDefaultSpannerClient(exec *SpannerExecutor) *defaultSpannerClient { + return &defaultSpannerClient{exec: exec} +} + +// normalizedClient encapsulates the Spanner client for the normalized schema. +type normalizedClient struct { + exec *SpannerExecutor +} - // Set an initial timestamp synchronously before starting the background loop. - sc.ticker = NewTimestampTicker() - sc.stopCh = make(chan struct{}) - sc.updateTimestamp = sc.fetchAndUpdateTimestamp - if err := sc.updateTimestamp(context.Background()); err != nil { - slog.Error("Error initializing Spanner staleness timestamp", "error", err.Error()) +// NewNormalizedClient creates a new normalizedClient. +func NewNormalizedClient(client SpannerClient) (*normalizedClient, error) { + sc, ok := client.(*defaultSpannerClient) + if !ok { + err := fmt.Errorf("NewNormalizedClient: expected *defaultSpannerClient, got %T", client) + slog.Error("Failed to create normalized client", "error", err) return nil, err } - return sc, nil + return &normalizedClient{exec: sc.exec}, nil } // NewRawSpannerClient creates a new SpannerClient without the schema selector. @@ -94,23 +95,43 @@ func newSpannerDatabaseClient(client *spanner.Client) (*spannerDatabaseClient, e func NewRawSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride string) (SpannerClient, error) { cfg, err := createSpannerConfig(spannerConfigYaml, databaseOverride) if err != nil { - return nil, fmt.Errorf("failed to create spannerDatabaseClient: %w", err) + return nil, fmt.Errorf("failed to create defaultSpannerClient: %w", err) } client, err := createSpannerClient(ctx, cfg) if err != nil { - return nil, fmt.Errorf("failed to create spannerDatabaseClient: %w", err) + return nil, fmt.Errorf("failed to create defaultSpannerClient: %w", err) } - return newSpannerDatabaseClient(client) + exec, err := NewSpannerExecutor(client) + if err != nil { + return nil, fmt.Errorf("failed to create defaultSpannerClient: %w", err) + } + return newDefaultSpannerClient(exec), nil } // NewSpannerClient creates a new SpannerClient from the config yaml string and an optional database override. // It returns a wrapper client that handles request-time schema dispatching. func NewSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride string) (SpannerClient, error) { - rawClient, err := NewRawSpannerClient(ctx, spannerConfigYaml, databaseOverride) + cfg, err := createSpannerConfig(spannerConfigYaml, databaseOverride) if err != nil { return nil, err } - return NewSchemaSelectorClient(rawClient) + client, err := createSpannerClient(ctx, cfg) + if err != nil { + return nil, err + } + exec, err := NewSpannerExecutor(client) + if err != nil { + return nil, err + } + + defaultClient := newDefaultSpannerClient(exec) + normalizedClient, err := NewNormalizedClient(defaultClient) + if err != nil { + slog.Error("Failed to create normalized client in NewSpannerClient", "error", err) + return nil, err + } + + return NewSchemaSelectorClient(defaultClient, normalizedClient) } // createSpannerClient creates the database name string and initializes the Spanner client. @@ -145,56 +166,21 @@ func createSpannerConfig(spannerConfigYaml, databaseOverride string) (*SpannerCo return &cfg, nil } -func (sc *spannerDatabaseClient) Id() string { - return sc.client.DatabaseName() +func (sc *defaultSpannerClient) Id() string { + return sc.exec.Id() } // Start starts the background goroutine to periodically fetch the timestamp. -func (sc *spannerDatabaseClient) Start() { - sc.startOnce.Do(func() { - ctx, cancel := context.WithCancel(context.Background()) - - sc.wg.Add(1) - go func() { - // Defer statements are processed in LIFO order. - // Mark the wait group as done. - defer sc.wg.Done() - // Cancel the context to clean up any in-flight operations. - defer cancel() - // Stop the ticker. - defer sc.ticker.Stop() - - for { - select { - case <-sc.stopCh: - return - case <-sc.ticker.C(): - // Ignore the error here to allow the process to continue running - // even if one fetch fails. The previous timestamp remains in cache. - err := sc.updateTimestamp(ctx) - if err != nil { - slog.Error("Error updating Spanner staleness timestamp", "error", err) - } - } - } - }() - }) +func (sc *defaultSpannerClient) Start() { + sc.exec.Start() } // Close closes the Spanner client and stops the background goroutine. -func (sc *spannerDatabaseClient) Close() { - sc.stopOnce.Do(func() { - close(sc.stopCh) - - sc.wg.Wait() - - if sc.client != nil { - sc.client.Close() - } - }) +func (sc *defaultSpannerClient) Close() { + sc.exec.Close() } // GetSdmxObservations is not supported on the default client. -func (sc *spannerDatabaseClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { +func (sc *defaultSpannerClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { return nil, status.Error(codes.Unimplemented, "SDMX queries are only supported on the normalized schema") } diff --git a/internal/server/spanner/client_normalized.go b/internal/server/spanner/client_normalized.go deleted file mode 100644 index 606c5157a..000000000 --- a/internal/server/spanner/client_normalized.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package spanner - -import "fmt" - -// TODO(task): Decouple normalizedClient from spannerDatabaseClient by extracting -// common execution and staleness logic into a shared executor. - -// normalizedClient encapsulates the Spanner client for the normalized schema. -type normalizedClient struct { - sc *spannerDatabaseClient -} - -// NewNormalizedClient creates a new normalizedClient. -// It takes a SpannerClient interface and type-asserts it to *spannerDatabaseClient -// to reuse internal helpers like queryStructs. This is a compromise to avoid -// exporting internal implementation details while allowing tests in the golden package -// to construct it. -func NewNormalizedClient(client SpannerClient) (*normalizedClient, error) { - sc, ok := client.(*spannerDatabaseClient) - if !ok { - return nil, fmt.Errorf("NewNormalizedClient: expected *spannerDatabaseClient, got %T", client) - } - return &normalizedClient{sc: sc}, nil -} diff --git a/internal/server/spanner/client_selector.go b/internal/server/spanner/client_selector.go index e3004a22c..aa269bd45 100644 --- a/internal/server/spanner/client_selector.go +++ b/internal/server/spanner/client_selector.go @@ -27,7 +27,7 @@ import ( // schemaSelectorClient dispatches calls to either default or normalized client. type schemaSelectorClient struct { SpannerClient // Embeds the default client - normalized *normalizedClient + normalized NormalizedObservationProvider } // GetObservations overrides the embedded client's GetObservations to dispatch based on schema selection. @@ -77,12 +77,7 @@ func (s *schemaSelectorClient) GetSdmxObservations(ctx context.Context, req *pb. } // NewSchemaSelectorClient creates a new SpannerClient that dispatches calls to either default or normalized client. -func NewSchemaSelectorClient(baseClient SpannerClient) (SpannerClient, error) { - normalizedClient, err := NewNormalizedClient(baseClient) - if err != nil { - return nil, err - } - +func NewSchemaSelectorClient(baseClient SpannerClient, normalizedClient NormalizedObservationProvider) (SpannerClient, error) { return &schemaSelectorClient{ SpannerClient: baseClient, normalized: normalizedClient, diff --git a/internal/server/spanner/executor.go b/internal/server/spanner/executor.go new file mode 100644 index 000000000..dba92580e --- /dev/null +++ b/internal/server/spanner/executor.go @@ -0,0 +1,352 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "sync/atomic" + "time" + + "cloud.google.com/go/spanner" + "github.com/datacommonsorg/mixer/internal/metrics" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/api/iterator" +) + +// SpannerExecutor handles the low-level details of connecting to Spanner and executing queries. +type SpannerExecutor struct { + client *spanner.Client + timestamp atomic.Int64 + ticker Ticker + stopCh chan struct{} + startOnce sync.Once + stopOnce sync.Once + wg sync.WaitGroup + updateTimestamp func(context.Context) error +} + +// NewSpannerExecutor creates a new SpannerExecutor. +func NewSpannerExecutor(client *spanner.Client) (*SpannerExecutor, error) { + se := &SpannerExecutor{ + client: client, + } + + // Set an initial timestamp synchronously before starting the background loop. + se.ticker = NewTimestampTicker() + se.stopCh = make(chan struct{}) + se.updateTimestamp = se.fetchAndUpdateTimestamp + if err := se.updateTimestamp(context.Background()); err != nil { + slog.Error("Error initializing Spanner staleness timestamp", "error", err.Error()) + return nil, err + } + return se, nil +} + +// Id returns the database name. +func (se *SpannerExecutor) Id() string { + return se.client.DatabaseName() +} + +// Start starts the background goroutine to periodically fetch the timestamp. +func (se *SpannerExecutor) Start() { + se.startOnce.Do(func() { + ctx, cancel := context.WithCancel(context.Background()) + + se.wg.Add(1) + go func() { + defer se.wg.Done() + defer cancel() + defer se.ticker.Stop() + + for { + select { + case <-se.stopCh: + return + case <-se.ticker.C(): + err := se.updateTimestamp(ctx) + if err != nil { + slog.Error("Error updating Spanner staleness timestamp", "error", err) + } + } + } + }() + }) +} + +// Close closes the Spanner client and stops the background goroutine. +func (se *SpannerExecutor) Close() { + se.stopOnce.Do(func() { + close(se.stopCh) + se.wg.Wait() + if se.client != nil { + se.client.Close() + } + }) +} + +// fetchAndUpdateTimestamp queries Spanner and updates the timestamp. +func (se *SpannerExecutor) fetchAndUpdateTimestamp(ctx context.Context) error { + queryCtx, cancel := context.WithTimeout(ctx, timestampPollingTimeout) + defer cancel() + + iter := se.client.Single().Query(queryCtx, *GetCompletionTimestampQuery()) + defer iter.Stop() + + row, err := iter.Next() + + var warnMsg string + if err == iterator.Done { + warnMsg = "No valid rows found in IngestionHistory." + } else if code := spanner.ErrCode(err); code == codes.NotFound || + (code == codes.InvalidArgument && strings.Contains(err.Error(), "Table not found: IngestionHistory")) { + warnMsg = "IngestionHistory table not found." + } + + if warnMsg != "" { + slog.Warn(warnMsg + " Falling back to strong reads.") + return nil + } + + if err != nil { + if isTimeoutError(err) { + slog.ErrorContext(queryCtx, "Spanner timestamp polling timed out", + "timeout_duration", timestampPollingTimeout.String(), + "error", err.Error(), + ) + } + return fmt.Errorf("failed to fetch row: %w", err) + } + + var timestamp time.Time + if err := row.Column(0, ×tamp); err != nil { + return fmt.Errorf("failed to read CompletionTimestamp column: %w", err) + } + + se.timestamp.Store(timestamp.UnixNano()) + return nil +} + +func (se *SpannerExecutor) getStalenessTimestamp() (time.Time, error) { + val := se.timestamp.Load() + if val != 0 { + return time.Unix(0, val).UTC(), nil + } + slog.Error("Spanner staleness timestamp not available") + return time.Time{}, fmt.Errorf("error getting staleness timestamp") +} + +func (se *SpannerExecutor) executeQuery( + ctx context.Context, + stmt spanner.Statement, + handleRows func(*spanner.RowIterator) error, +) error { + var queryCtx context.Context + var cancel context.CancelFunc + + if _, ok := ctx.Deadline(); ok { + queryCtx, cancel = context.WithCancel(ctx) + } else { + slog.Warn("Parent context has no deadline; using default API timeout", "timeout", ApiTimeout.String()) + queryCtx, cancel = context.WithTimeout(ctx, ApiTimeout) + } + defer cancel() + + runQuery := func(tb spanner.TimestampBound) error { + metrics.RecordSpannerQuery(queryCtx) + startTime := time.Now() + iter := se.client.Single().WithTimestampBound(tb).Query(queryCtx, stmt) + defer iter.Stop() + err := handleRows(iter) + duration := time.Since(startTime) + + if shouldLogSQL(queryCtx) { + interpolatedSQL := InterpolateSQL(&stmt) + schema := getSchemaName(queryCtx) + fmt.Printf("\n=== [%s] Spanner Query (Took %v) ===\n", schema, duration) + fmt.Println("[Parameterized Query]") + for k, v := range stmt.Params { + jsonVal, _ := json.Marshal(v) + fmt.Printf("SET @%s = %s;\n", k, string(jsonVal)) + } + fmt.Println() + fmt.Println(stmt.SQL) + fmt.Println("\n[Interpolated Query]") + fmt.Println(interpolatedSQL) + fmt.Println("================================================") + } + + if isTimeoutError(err) { + slog.ErrorContext(queryCtx, "Spanner query timed out", + "sql", stmt.SQL, + "error", err.Error(), + ) + } + + return err + } + + ts, err := se.getStalenessTimestamp() + if err != nil { + return runQuery(spanner.StrongRead()) + } + err = runQuery(spanner.ReadTimestamp(ts)) + + if spanner.ErrCode(err) == codes.FailedPrecondition { + slog.Error("Stale read timestamp expired. Falling back to StrongRead.", + "expiredTimestamp", ts.String()) + return runQuery(spanner.StrongRead()) + } + return err +} + +// queryStructs executes a query and maps the results to an input struct. +func (se *SpannerExecutor) queryStructs( + ctx context.Context, + stmt spanner.Statement, + newStruct func() interface{}, + withStruct func(interface{}), +) error { + return se.executeQuery(ctx, stmt, func(iter *spanner.RowIterator) error { + return processRows(iter, newStruct, withStruct) + }) +} + +// queryDynamic executes a dynamically constructed query and returns the results as a slice of string slices. +func (se *SpannerExecutor) queryDynamic( + ctx context.Context, + stmt spanner.Statement, +) ([][]string, error) { + var rowData [][]string + err := se.executeQuery(ctx, stmt, func(iter *spanner.RowIterator) error { + result, err := processDynamicRows(iter) + rowData = result + return err + }) + return rowData, err +} + +// queryCache executes a query and maps the results to an input cache proto. +func queryCache[T proto.Message]( + ctx context.Context, + se *SpannerExecutor, + stmt spanner.Statement, + newProto func() T, +) (map[string]map[string]T, error) { + var data map[string]map[string]T + err := se.executeQuery(ctx, stmt, func(iter *spanner.RowIterator) error { + result, err := processCacheRows(iter, newProto) + data = result + return err + }) + return data, err +} + +func processRows(iter *spanner.RowIterator, newStruct func() interface{}, withStruct func(interface{})) error { + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return fmt.Errorf("failed to fetch row: %w", err) + } + + rowStruct := newStruct() + if err := row.ToStructLenient(rowStruct); err != nil { + return fmt.Errorf("failed to parse row: %w", err) + } + withStruct(rowStruct) + } + return nil +} + +func processDynamicRows(iter *spanner.RowIterator) ([][]string, error) { + rowData := [][]string{} + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return rowData, err + } + + data := []string{} + for i := 0; i < row.Size(); i++ { + var val spanner.GenericColumnValue + if err := row.Column(i, &val); err != nil { + return rowData, err + } + data = append(data, val.Value.GetStringValue()) + } + rowData = append(rowData, data) + } + return rowData, nil +} + +func processCacheRows[T proto.Message](iter *spanner.RowIterator, newProto func() T) (map[string]map[string]T, error) { + results := make(map[string]map[string]T) + unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} + + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("failed to fetch row: %w", err) + } + + var key string + if err := row.ColumnByName("key", &key); err != nil { + return nil, fmt.Errorf("failed to read key column: %w", err) + } + + var provenance string + if err := row.ColumnByName("provenance", &provenance); err != nil { + return nil, fmt.Errorf("failed to read provenance column: %w", err) + } + + var jsonStr spanner.NullString + if err := row.ColumnByName("value", &jsonStr); err != nil { + return nil, fmt.Errorf("failed to read value column: %w", err) + } + + if jsonStr.Valid { + msg := newProto() + if err := unmarshaler.Unmarshal([]byte(jsonStr.StringVal), msg); err != nil { + return nil, fmt.Errorf("failed to unmarshal proto: %w", err) + } + + if results[key] == nil { + results[key] = make(map[string]T) + } + results[key][provenance] = msg + } + } + return results, nil +} + +func isTimeoutError(err error) bool { + return spanner.ErrCode(err) == codes.DeadlineExceeded || errors.Is(err, context.DeadlineExceeded) +} diff --git a/internal/server/spanner/query.go b/internal/server/spanner/query.go index 51a58f258..d3d1b9653 100644 --- a/internal/server/spanner/query.go +++ b/internal/server/spanner/query.go @@ -17,8 +17,6 @@ package spanner import ( "context" - "encoding/json" - "errors" "fmt" "log/slog" "net/url" @@ -28,7 +26,6 @@ import ( "time" "cloud.google.com/go/spanner" - "github.com/datacommonsorg/mixer/internal/metrics" pb "github.com/datacommonsorg/mixer/internal/proto" pbv1 "github.com/datacommonsorg/mixer/internal/proto/v1" "github.com/datacommonsorg/mixer/internal/server/datasources" @@ -37,8 +34,6 @@ import ( "github.com/datacommonsorg/mixer/internal/util" "golang.org/x/sync/errgroup" "google.golang.org/api/iterator" - "google.golang.org/grpc/codes" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) @@ -78,7 +73,7 @@ const ( ) // GetNodeProps retrieves node properties from Spanner given a list of IDs and a direction and returns a map. -func (sc *spannerDatabaseClient) GetNodeProps(ctx context.Context, ids []string, out bool) (map[string][]*Property, error) { +func (sc *defaultSpannerClient) GetNodeProps(ctx context.Context, ids []string, out bool) (map[string][]*Property, error) { props := map[string][]*Property{} if len(ids) == 0 { return props, nil @@ -87,9 +82,8 @@ func (sc *spannerDatabaseClient) GetNodeProps(ctx context.Context, ids []string, props[id] = []*Property{} } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *GetNodePropsQuery(ids, out), func() interface{} { return &Property{} @@ -108,7 +102,7 @@ func (sc *spannerDatabaseClient) GetNodeProps(ctx context.Context, ids []string, } // GetNodeEdgesByID retrieves node edges from Spanner and returns a map of subjectID to Edges. -func (sc *spannerDatabaseClient) GetNodeEdgesByID(ctx context.Context, ids []string, arc *v2.Arc, pageSize, offset int) (map[string][]*Edge, error) { +func (sc *defaultSpannerClient) GetNodeEdgesByID(ctx context.Context, ids []string, arc *v2.Arc, pageSize, offset int) (map[string][]*Edge, error) { edges := make(map[string][]*Edge) if len(ids) == 0 { return edges, nil @@ -117,9 +111,8 @@ func (sc *spannerDatabaseClient) GetNodeEdgesByID(ctx context.Context, ids []str edges[id] = []*Edge{} } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *GetNodeEdgesByIDQuery(ids, arc, pageSize, offset), func() interface{} { return &Edge{} @@ -138,15 +131,14 @@ func (sc *spannerDatabaseClient) GetNodeEdgesByID(ctx context.Context, ids []str } // GetObservations retrieves observations from Spanner given a list of variables and entities. -func (sc *spannerDatabaseClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { +func (sc *defaultSpannerClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { var observations []*Observation if len(entities) == 0 { return nil, fmt.Errorf("entity must be specified") } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *GetObservationsQuery(variables, entities), func() interface{} { return &Observation{} @@ -165,24 +157,23 @@ func (sc *spannerDatabaseClient) GetObservations(ctx context.Context, variables // CheckVariableExistence checks for the existence of observations for the given variables and entities. // Returns a slice of rows, where each row contains [variable, entity] that has at least one observation. -func (sc *spannerDatabaseClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { +func (sc *defaultSpannerClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { stmt, err := FilterStatVarsByEntityQuery(variables, entities) if err != nil { return nil, err } - return queryDynamic(ctx, sc, *stmt) + return sc.exec.queryDynamic(ctx, *stmt) } // GetObservationsContainedInPlace retrieves observations from Spanner given a list of variables and an entity expression. -func (sc *spannerDatabaseClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { +func (sc *defaultSpannerClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { var observations []*Observation if len(variables) == 0 || containedInPlace == nil { return observations, nil } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *GetObservationsContainedInPlaceQuery(variables, containedInPlace), func() interface{} { return &Observation{} @@ -202,15 +193,14 @@ func (sc *spannerDatabaseClient) GetObservationsContainedInPlace(ctx context.Con // SearchNodes searches nodes in the graph based on the query and optionally the types. // If the types array is empty, it searches across nodes of all types. // A maximum of 100 results are returned. -func (sc *spannerDatabaseClient) SearchNodes(ctx context.Context, query string, types []string) ([]*SearchNode, error) { +func (sc *defaultSpannerClient) SearchNodes(ctx context.Context, query string, types []string) ([]*SearchNode, error) { var nodes []*SearchNode if query == "" { return nodes, nil } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *SearchNodesQuery(query, types), func() interface{} { return &SearchNode{} @@ -228,7 +218,7 @@ func (sc *spannerDatabaseClient) SearchNodes(ctx context.Context, query string, } // ResolveByID fetches ID resolution candidates for a list of input nodes and in and out properties and returns a map of node to candidates. -func (sc *spannerDatabaseClient) ResolveByID(ctx context.Context, nodes []string, in, out string) (map[string][]string, error) { +func (sc *defaultSpannerClient) ResolveByID(ctx context.Context, nodes []string, in, out string) (map[string][]string, error) { nodeToCandidates := make(map[string][]string) if len(nodes) == 0 { return nodeToCandidates, nil @@ -242,9 +232,8 @@ func (sc *spannerDatabaseClient) ResolveByID(ctx context.Context, nodes []string valueMap[value] = node } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *ResolveByIDQuery(nodes, in, out), func() interface{} { return &ResolutionCandidate{} @@ -263,9 +252,9 @@ func (sc *spannerDatabaseClient) ResolveByID(ctx context.Context, nodes []string } // GetEventCollectionDate retrieves event collection dates from Spanner. -func (sc *spannerDatabaseClient) GetEventCollectionDate(ctx context.Context, placeID, eventType string) ([]string, error) { +func (sc *defaultSpannerClient) GetEventCollectionDate(ctx context.Context, placeID, eventType string) ([]string, error) { stmt := GetEventCollectionDateQuery(placeID, eventType) - rows, err := queryDynamic(ctx, sc, *stmt) + rows, err := sc.exec.queryDynamic(ctx, *stmt) if err != nil { return nil, err } @@ -280,7 +269,7 @@ func (sc *spannerDatabaseClient) GetEventCollectionDate(ctx context.Context, pla } // GetEventCollection retrieves and filters event collection from Spanner. -func (sc *spannerDatabaseClient) GetEventCollection(ctx context.Context, req *pbv1.EventCollectionRequest) (*pbv1.EventCollection, error) { +func (sc *defaultSpannerClient) GetEventCollection(ctx context.Context, req *pbv1.EventCollectionRequest) (*pbv1.EventCollection, error) { // Get event DCIDs eventRows, err := sc.GetEventCollectionDcids(ctx, req.AffectedPlaceDcid, req.EventType, req.Date) if err != nil { @@ -312,7 +301,7 @@ func (sc *spannerDatabaseClient) GetEventCollection(ctx context.Context, req *pb return res, nil } -func (sc *spannerDatabaseClient) populateProvenanceInfo(ctx context.Context, res *pbv1.EventCollection) error { +func (sc *defaultSpannerClient) populateProvenanceInfo(ctx context.Context, res *pbv1.EventCollection) error { provDcids := []string{} seen := map[string]bool{} for _, event := range res.Events { @@ -517,9 +506,9 @@ func keepEvent(event *pbv1.EventCollection_Event, req *pbv1.EventCollectionReque } // GetEventCollectionDcids retrieves event DCIDs from Spanner. -func (sc *spannerDatabaseClient) GetEventCollectionDcids(ctx context.Context, placeID, eventType, date string) ([]EventIdWithMagnitudeDcid, error) { +func (sc *defaultSpannerClient) GetEventCollectionDcids(ctx context.Context, placeID, eventType, date string) ([]EventIdWithMagnitudeDcid, error) { stmt := GetEventCollectionDcidsQuery(placeID, eventType, date) - rows, err := queryDynamic(ctx, sc, *stmt) + rows, err := sc.exec.queryDynamic(ctx, *stmt) if err != nil { return nil, err } @@ -596,16 +585,16 @@ func parseAndSortEvents(rows []EventIdWithMagnitudeDcid, eventType string) []str return res } -func (sc *spannerDatabaseClient) Sparql(ctx context.Context, nodes []types.Node, queries []*types.Query, opts *types.QueryOptions) ([][]string, error) { +func (sc *defaultSpannerClient) Sparql(ctx context.Context, nodes []types.Node, queries []*types.Query, opts *types.QueryOptions) ([][]string, error) { query, err := SparqlQuery(nodes, queries, opts) if err != nil { return nil, fmt.Errorf("error building sparql query: %v", err) } - return queryDynamic(ctx, sc, *query) + return sc.exec.queryDynamic(ctx, *query) } -func (sc *spannerDatabaseClient) GetProvenanceSummary(ctx context.Context, variables []string) (map[string]map[string]*pb.StatVarSummary_ProvenanceSummary, error) { +func (sc *defaultSpannerClient) GetProvenanceSummary(ctx context.Context, variables []string) (map[string]map[string]*pb.StatVarSummary_ProvenanceSummary, error) { if len(variables) == 0 { return map[string]map[string]*pb.StatVarSummary_ProvenanceSummary{}, nil @@ -613,7 +602,7 @@ func (sc *spannerDatabaseClient) GetProvenanceSummary(ctx context.Context, varia results, err := queryCache( ctx, - sc, + sc.exec, *GetCacheDataQuery(TypeProvenanceSummary, variables), func() *pb.StatVarSummary_ProvenanceSummary { return &pb.StatVarSummary_ProvenanceSummary{} @@ -627,9 +616,9 @@ func (sc *spannerDatabaseClient) GetProvenanceSummary(ctx context.Context, varia } // GetTermEmbeddingQuery retrieves embeddings from Spanner for a given query. -func (sc *spannerDatabaseClient) GetTermEmbeddingQuery(ctx context.Context, modelName, searchLabel, taskType string) ([]float64, error) { +func (sc *defaultSpannerClient) GetTermEmbeddingQuery(ctx context.Context, modelName, searchLabel, taskType string) ([]float64, error) { embeddings := []float64{} - err := sc.executeQuery(ctx, *GetTermEmbeddingQuery(modelName, searchLabel, taskType), func(iter *spanner.RowIterator) error { + err := sc.exec.executeQuery(ctx, *GetTermEmbeddingQuery(modelName, searchLabel, taskType), func(iter *spanner.RowIterator) error { row, err := iter.Next() if err == iterator.Done { return nil @@ -643,7 +632,7 @@ func (sc *spannerDatabaseClient) GetTermEmbeddingQuery(ctx context.Context, mode } // FilterNodesByTypes filters a list of nodes by types and returns a map of node to matched types. -func (sc *spannerDatabaseClient) FilterNodesByTypes(ctx context.Context, nodes []string, typeFilters []string) (map[string][]string, error) { +func (sc *defaultSpannerClient) FilterNodesByTypes(ctx context.Context, nodes []string, typeFilters []string) (map[string][]string, error) { if len(nodes) == 0 { return map[string][]string{}, nil } @@ -656,9 +645,8 @@ func (sc *spannerDatabaseClient) FilterNodesByTypes(ctx context.Context, nodes [ MatchedTypes []string `spanner:"matched_types"` } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *stmt, func() interface{} { return &rowResult{} }, func(rowStruct interface{}) { @@ -674,11 +662,10 @@ func (sc *spannerDatabaseClient) FilterNodesByTypes(ctx context.Context, nodes [ } // VectorSearchQuery performs vector similarity search in Spanner. -func (sc *spannerDatabaseClient) VectorSearchQuery(ctx context.Context, tableName string, limit int, embeddings []float64, numLeaves int, threshold float64, nodeTypes []string) ([]*VectorSearchResult, error) { +func (sc *defaultSpannerClient) VectorSearchQuery(ctx context.Context, tableName string, limit int, embeddings []float64, numLeaves int, threshold float64, nodeTypes []string) ([]*VectorSearchResult, error) { var results []*VectorSearchResult - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *VectorSearchQuery(tableName, limit, embeddings, numLeaves, threshold, nodeTypes), func() interface{} { return &VectorSearchResult{} @@ -692,15 +679,14 @@ func (sc *spannerDatabaseClient) VectorSearchQuery(ctx context.Context, tableNam } // GetStatVarGroupNode fetches StatVarGroupNode info from Spanner. -func (sc *spannerDatabaseClient) GetStatVarGroupNode(ctx context.Context, nodes []string, includeDefinitions bool) ([]*StatVarGroupNode, error) { +func (sc *defaultSpannerClient) GetStatVarGroupNode(ctx context.Context, nodes []string, includeDefinitions bool) ([]*StatVarGroupNode, error) { var svgNodes []*StatVarGroupNode if len(nodes) == 0 { return svgNodes, nil } - err := queryStructs( + err := sc.exec.queryStructs( ctx, - sc, *GetStatVarGroupNodeQuery(nodes, includeDefinitions), func() interface{} { return &StatVarGroupNode{} @@ -717,7 +703,7 @@ func (sc *spannerDatabaseClient) GetStatVarGroupNode(ctx context.Context, nodes } // GetFilteredStatVarGroupNode fetches filtered StatVarGroupNode info from Spanner. -func (sc *spannerDatabaseClient) GetFilteredStatVarGroupNode(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (map[string]*FilteredStatVarGroupNode, error) { +func (sc *defaultSpannerClient) GetFilteredStatVarGroupNode(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (map[string]*FilteredStatVarGroupNode, error) { response := map[string]*FilteredStatVarGroupNode{} errGroup, errCtx := errgroup.WithContext(ctx) errGroup.SetLimit(maxConcurrentFilteredSVGGoroutines) // Limit the number of concurrent goroutines to avoid overwhelming Spanner with too many requests. @@ -753,7 +739,7 @@ func (sc *spannerDatabaseClient) GetFilteredStatVarGroupNode(ctx context.Context } // getSingleFilteredStatVarGroupNode fetches the relevant info to build a single filtered StatVarGroupNode from Spanner. -func (sc *spannerDatabaseClient) getSingleFilteredStatVarGroupNode(ctx context.Context, node string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (*FilteredStatVarGroupNode, error) { +func (sc *defaultSpannerClient) getSingleFilteredStatVarGroupNode(ctx context.Context, node string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (*FilteredStatVarGroupNode, error) { filteredStatVarGroupNode := &FilteredStatVarGroupNode{} errGroup, errCtx := errgroup.WithContext(ctx) svgChildChan := make(chan []*SVGChild, 1) @@ -762,9 +748,8 @@ func (sc *spannerDatabaseClient) getSingleFilteredStatVarGroupNode(ctx context.C errGroup.Go(func() error { var svgChildren []*SVGChild - err := queryStructs( + err := sc.exec.queryStructs( errCtx, - sc, *GetSVGChildrenQuery(node, includeDefinitions), func() interface{} { return &SVGChild{} @@ -782,9 +767,8 @@ func (sc *spannerDatabaseClient) getSingleFilteredStatVarGroupNode(ctx context.C errGroup.Go(func() error { var childSVs []*ChildSV - err := queryStructs( + err := sc.exec.queryStructs( errCtx, - sc, *GetFilteredSVGChildrenQuery(templateSV, node, constrainedPlaces, constrainedImport, numEntitiesExistence, includeDefinitions), func() interface{} { return &ChildSV{} @@ -802,9 +786,8 @@ func (sc *spannerDatabaseClient) getSingleFilteredStatVarGroupNode(ctx context.C errGroup.Go(func() error { var childSVGs []*ChildSVG - err := queryStructs( + err := sc.exec.queryStructs( errCtx, - sc, *GetFilteredSVGChildrenQuery(templateSVG, node, constrainedPlaces, constrainedImport, numEntitiesExistence, includeDefinitions), func() interface{} { return &ChildSVG{} @@ -836,14 +819,14 @@ func (sc *spannerDatabaseClient) getSingleFilteredStatVarGroupNode(ctx context.C } // GetFilteredTopic fetches the relevant info to build a filtered Topic response from Spanner. -func (sc *spannerDatabaseClient) GetFilteredTopic(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int) (map[string]int, error) { +func (sc *defaultSpannerClient) GetFilteredTopic(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int) (map[string]int, error) { counts := make(map[string]int, len(nodes)) for _, node := range nodes { counts[node] = 0 } stmt := GetFilteredTopicChildrenQuery(nodes, constrainedPlaces, constrainedImport, numEntitiesExistence) - err := sc.executeQuery(ctx, *stmt, func(iter *spanner.RowIterator) error { + err := sc.exec.executeQuery(ctx, *stmt, func(iter *spanner.RowIterator) error { for { row, err := iter.Next() if err == iterator.Done { @@ -869,261 +852,5 @@ func (sc *spannerDatabaseClient) GetFilteredTopic(ctx context.Context, nodes []s return counts, nil } -// fetchAndUpdateTimestamp queries Spanner and updates the timestamp. -func (sc *spannerDatabaseClient) fetchAndUpdateTimestamp(ctx context.Context) error { - queryCtx, cancel := context.WithTimeout(ctx, timestampPollingTimeout) - defer cancel() - iter := sc.client.Single().Query(queryCtx, *GetCompletionTimestampQuery()) - defer iter.Stop() - row, err := iter.Next() - - // Handle missing or empty table cases gracefully - var warnMsg string - if err == iterator.Done { - warnMsg = "No valid rows found in IngestionHistory." - } else if code := spanner.ErrCode(err); code == codes.NotFound || - (code == codes.InvalidArgument && strings.Contains(err.Error(), "Table not found: IngestionHistory")) { - warnMsg = "IngestionHistory table not found." - } - - if warnMsg != "" { - slog.Warn(warnMsg + " Falling back to strong reads.") - return nil - } - - if err != nil { - if isTimeoutError(err) { - slog.ErrorContext(queryCtx, "Spanner timestamp polling timed out", - "timeout_duration", timestampPollingTimeout.String(), - "error", err.Error(), - ) - } - return fmt.Errorf("failed to fetch row: %w", err) - } - - var timestamp time.Time - if err := row.Column(0, ×tamp); err != nil { - return fmt.Errorf("failed to read CompletionTimestamp column: %w", err) - } - - sc.timestamp.Store(timestamp.UnixNano()) - return nil -} - -func (sc *spannerDatabaseClient) getStalenessTimestamp() (time.Time, error) { - val := sc.timestamp.Load() - if val != 0 { - return time.Unix(0, val).UTC(), nil - } - slog.Error("Spanner staleness timestamp not available") - return time.Time{}, fmt.Errorf("error getting staleness timestamp") -} - -func (sc *spannerDatabaseClient) executeQuery( - ctx context.Context, - stmt spanner.Statement, - handleRows func(*spanner.RowIterator) error, -) error { - var queryCtx context.Context - var cancel context.CancelFunc - - if _, ok := ctx.Deadline(); ok { - queryCtx, cancel = context.WithCancel(ctx) - } else { - // Fallback if the parent context surprisingly has no deadline. - // Using the default API timeout. - slog.Warn("Parent context has no deadline; using default API timeout", "timeout", ApiTimeout.String()) - queryCtx, cancel = context.WithTimeout(ctx, ApiTimeout) - } - defer cancel() - - runQuery := func(tb spanner.TimestampBound) error { - metrics.RecordSpannerQuery(queryCtx) - startTime := time.Now() - iter := sc.client.Single().WithTimestampBound(tb).Query(queryCtx, stmt) - defer iter.Stop() - err := handleRows(iter) - duration := time.Since(startTime) - - if shouldLogSQL(queryCtx) { - interpolatedSQL := InterpolateSQL(&stmt) - schema := getSchemaName(queryCtx) - fmt.Printf("\n=== [%s] Spanner Query (Took %v) ===\n", schema, duration) - fmt.Println("[Parameterized Query]") - for k, v := range stmt.Params { - jsonVal, _ := json.Marshal(v) - fmt.Printf("SET @%s = %s;\n", k, string(jsonVal)) - } - fmt.Println() - fmt.Println(stmt.SQL) - fmt.Println("\n[Interpolated Query]") - fmt.Println(interpolatedSQL) - fmt.Println("================================================") - } - - // Log slow Spanner queries that timed out. - if isTimeoutError(err) { - slog.ErrorContext(queryCtx, "Spanner query timed out", - "sql", stmt.SQL, - "error", err.Error(), - ) - } - - return err - } - - ts, err := sc.getStalenessTimestamp() - if err != nil { - return runQuery(spanner.StrongRead()) - } - err = runQuery(spanner.ReadTimestamp(ts)) - - // Log error if timestamp is older than retention and fall back to strong read. - if spanner.ErrCode(err) == codes.FailedPrecondition { - slog.Error("Stale read timestamp expired. Falling back to StrongRead.", - "expiredTimestamp", ts.String()) - return runQuery(spanner.StrongRead()) - } - return err -} - -// queryStructs executes a query and maps the results to an input struct. -func queryStructs( - ctx context.Context, - sc *spannerDatabaseClient, - stmt spanner.Statement, - newStruct func() interface{}, - withStruct func(interface{}), -) error { - return sc.executeQuery(ctx, stmt, func(iter *spanner.RowIterator) error { - return processRows(iter, newStruct, withStruct) - }) -} - -// queryDynamic executes a dynamically constructed query and returns the results as a slice of string slices. -func queryDynamic( - ctx context.Context, - sc *spannerDatabaseClient, - stmt spanner.Statement, -) ([][]string, error) { - var rowData [][]string - err := sc.executeQuery(ctx, stmt, func(iter *spanner.RowIterator) error { - result, err := processDynamicRows(iter) - rowData = result - return err - }) - return rowData, err -} - -// queryCache executes a query and maps the results to an input cache proto. -func queryCache[T proto.Message]( - ctx context.Context, - sc *spannerDatabaseClient, - stmt spanner.Statement, - newProto func() T, -) (map[string]map[string]T, error) { - var data map[string]map[string]T - err := sc.executeQuery(ctx, stmt, func(iter *spanner.RowIterator) error { - result, err := processCacheRows(iter, newProto) - data = result - return err - }) - return data, err -} - -func processRows(iter *spanner.RowIterator, newStruct func() interface{}, withStruct func(interface{})) error { - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return fmt.Errorf("failed to fetch row: %w", err) - } - - rowStruct := newStruct() - if err := row.ToStructLenient(rowStruct); err != nil { - return fmt.Errorf("failed to parse row: %w", err) - } - withStruct(rowStruct) - } - - return nil -} - -// processDynamicRows processes rows from dynamically constructed queries. -func processDynamicRows(iter *spanner.RowIterator) ([][]string, error) { - rowData := [][]string{} - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return rowData, err - } - - data := []string{} - for i := 0; i < row.Size(); i++ { - var val spanner.GenericColumnValue - if err := row.Column(i, &val); err != nil { - return rowData, err - } - data = append(data, val.Value.GetStringValue()) - } - rowData = append(rowData, data) - } - return rowData, nil -} - -// processCacheRows processes rows and maps them to a proto struct. -func processCacheRows[T proto.Message](iter *spanner.RowIterator, newProto func() T) (map[string]map[string]T, error) { - results := make(map[string]map[string]T) - unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} - - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("failed to fetch row: %w", err) - } - - var key string - if err := row.ColumnByName("key", &key); err != nil { - return nil, fmt.Errorf("failed to read key column: %w", err) - } - - var provenance string - if err := row.ColumnByName("provenance", &provenance); err != nil { - return nil, fmt.Errorf("failed to read provenance column: %w", err) - } - - var jsonStr spanner.NullString - if err := row.ColumnByName("value", &jsonStr); err != nil { - return nil, fmt.Errorf("failed to read value column: %w", err) - } - - if jsonStr.Valid { - msg := newProto() - if err := unmarshaler.Unmarshal([]byte(jsonStr.StringVal), msg); err != nil { - return nil, fmt.Errorf("failed to unmarshal proto: %w", err) - } - - if results[key] == nil { - results[key] = make(map[string]T) - } - results[key][provenance] = msg - } - } - - return results, nil -} - -// isTimeoutError checks if an error is a timeout error from Spanner or context. -func isTimeoutError(err error) bool { - return spanner.ErrCode(err) == codes.DeadlineExceeded || errors.Is(err, context.DeadlineExceeded) -} diff --git a/internal/server/spanner/query_normalized.go b/internal/server/spanner/query_normalized.go index 4fefbba66..1f4ce6b24 100644 --- a/internal/server/spanner/query_normalized.go +++ b/internal/server/spanner/query_normalized.go @@ -47,7 +47,7 @@ func (nc *normalizedClient) fetchRawObservations(ctx context.Context, variables stmt := GetNormalizedObservationsQuery(variables, entities) var rawObs []*rawObservation - err := queryStructs(ctx, nc.sc, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { + err := nc.exec.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { rawObs = append(rawObs, row.(*rawObservation)) }) return rawObs, err @@ -95,7 +95,7 @@ func (nc *normalizedClient) CheckVariableExistence(ctx context.Context, variable if err != nil { return nil, err } - return queryDynamic(ctx, nc.sc, *stmt) + return nc.exec.queryDynamic(ctx, *stmt) } // GetObservationsContainedInPlace retrieves observations for entities contained in a place @@ -122,7 +122,7 @@ func (nc *normalizedClient) fetchRawObservationsContainedInPlace(ctx context.Con stmt := GetNormalizedObservationsContainedInPlaceQuery(variables, containedInPlace) var rawObs []*rawObservation - err := queryStructs(ctx, nc.sc, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { + err := nc.exec.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { rawObs = append(rawObs, row.(*rawObservation)) }) return rawObs, err @@ -143,7 +143,7 @@ func (nc *normalizedClient) GetSdmxObservations(ctx context.Context, req *pb.Sdm stmt := GetSdmxObservationsQuery(req) var rawObs []*rawObservation - err := queryStructs(ctx, nc.sc, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { + err := nc.exec.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { rawObs = append(rawObs, row.(*rawObservation)) }) if err != nil { diff --git a/internal/server/spanner/timestamp_test.go b/internal/server/spanner/timestamp_test.go index d9ba21f2e..5d6cd7e72 100644 --- a/internal/server/spanner/timestamp_test.go +++ b/internal/server/spanner/timestamp_test.go @@ -53,7 +53,7 @@ func TestTimestampUpdated(t *testing.T) { startTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC) updateDone := make(chan bool, 1) - sc := &spannerDatabaseClient{ + sc := &SpannerExecutor{ ticker: mockTicker, stopCh: make(chan struct{}), } @@ -96,7 +96,7 @@ func TestTimestampUpdateFailure(t *testing.T) { startTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC) updateDone := make(chan bool, 1) - sc := &spannerDatabaseClient{ + sc := &SpannerExecutor{ ticker: mockTicker, stopCh: make(chan struct{}), } diff --git a/test/setup.go b/test/setup.go index c7b6658da..bef35ab0c 100644 --- a/test/setup.go +++ b/test/setup.go @@ -430,22 +430,17 @@ func newSpannerClient(ctx context.Context, spannerGraphInfoYamlPath string) span if err != nil { log.Fatalf("Failed to read spanner yaml: %v", err) } - // Don't override spannerGraphInfoYaml.database for testing. - spannerClient, err := spanner.NewRawSpannerClient(ctx, string(spannerGraphInfoYaml), "") + // Use NewSpannerClient to get the full setup with selector. + spannerClient, err := spanner.NewSpannerClient(ctx, string(spannerGraphInfoYaml), "") if err != nil { log.Fatalf("Failed to create SpannerClient: %v", err) } - // Use stale reads for testing. spannerClient.Start() return spannerClient } // NewSchemaSelectorSpannerClient creates a new test schema selector spanner client. func NewSchemaSelectorSpannerClient(t *testing.T) spanner.SpannerClient { - baseClient := NewNormalizedSpannerClient(t) - selectorClient, err := spanner.NewSchemaSelectorClient(baseClient) - if err != nil { - t.Fatalf("Failed to create SchemaSelectorClient: %v", err) - } - return selectorClient + // NewNormalizedSpannerClient already returns the selector client because newSpannerClient uses NewSpannerClient! + return NewNormalizedSpannerClient(t) } From 3e13a35129a19b84d502ef27575b139483514334 Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Thu, 14 May 2026 22:20:22 -0700 Subject: [PATCH 2/8] embedding --- internal/server/spanner/client.go | 71 +++++------ internal/server/spanner/client_selector.go | 117 ------------------ internal/server/spanner/executor.go | 19 +++ .../spanner/golden/query_normalized_test.go | 20 +-- internal/server/spanner/query.go | 42 +++---- internal/server/spanner/query_normalized.go | 35 +++++- 6 files changed, 106 insertions(+), 198 deletions(-) delete mode 100644 internal/server/spanner/client_selector.go diff --git a/internal/server/spanner/client.go b/internal/server/spanner/client.go index 609933609..7617dfa2a 100644 --- a/internal/server/spanner/client.go +++ b/internal/server/spanner/client.go @@ -56,56 +56,55 @@ type SpannerClient interface { Close() } -// NormalizedObservationProvider defines the subset of methods supported by the normalized schema path. -type NormalizedObservationProvider interface { - GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) - CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) - GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) - GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) -} - -// defaultSpannerClient encapsulates the Spanner client that directly interacts with the Spanner database. -type defaultSpannerClient struct { +// standardSpannerClient encapsulates the Spanner client that directly interacts with the Spanner database. +type standardSpannerClient struct { exec *SpannerExecutor } -// newDefaultSpannerClient creates a new defaultSpannerClient. -func newDefaultSpannerClient(exec *SpannerExecutor) *defaultSpannerClient { - return &defaultSpannerClient{exec: exec} +// newStandardSpannerClient creates a new standardSpannerClient. +func newStandardSpannerClient(exec *SpannerExecutor) *standardSpannerClient { + return &standardSpannerClient{exec: exec} } -// normalizedClient encapsulates the Spanner client for the normalized schema. -type normalizedClient struct { - exec *SpannerExecutor +// normalizedSchemaClient encapsulates the Spanner client for the normalized schema. +// It embeds SpannerClient to inherit default behavior and only overrides specific methods. +type normalizedSchemaClient struct { + SpannerClient // Embeds standardSpannerClient + exec *SpannerExecutor } -// NewNormalizedClient creates a new normalizedClient. -func NewNormalizedClient(client SpannerClient) (*normalizedClient, error) { - sc, ok := client.(*defaultSpannerClient) +// NewNormalizedClient creates a new normalizedSchemaClient. +func NewNormalizedClient(client SpannerClient) *normalizedSchemaClient { + sc, ok := client.(*standardSpannerClient) if !ok { - err := fmt.Errorf("NewNormalizedClient: expected *defaultSpannerClient, got %T", client) - slog.Error("Failed to create normalized client", "error", err) - return nil, err + panic(fmt.Sprintf("NewNormalizedClient: expected *standardSpannerClient, got %T", client)) + } + return &normalizedSchemaClient{ + SpannerClient: client, + exec: sc.exec, } - return &normalizedClient{exec: sc.exec}, nil } +// Force compiler that all methods required by the interface are implemented by clients +var _ SpannerClient = (*standardSpannerClient)(nil) +var _ SpannerClient = (*normalizedSchemaClient)(nil) + // NewRawSpannerClient creates a new SpannerClient without the schema selector. // This is intended for testing and internal use where a direct client is needed. func NewRawSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride string) (SpannerClient, error) { cfg, err := createSpannerConfig(spannerConfigYaml, databaseOverride) if err != nil { - return nil, fmt.Errorf("failed to create defaultSpannerClient: %w", err) + return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) } client, err := createSpannerClient(ctx, cfg) if err != nil { - return nil, fmt.Errorf("failed to create defaultSpannerClient: %w", err) + return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) } exec, err := NewSpannerExecutor(client) if err != nil { - return nil, fmt.Errorf("failed to create defaultSpannerClient: %w", err) + return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) } - return newDefaultSpannerClient(exec), nil + return newStandardSpannerClient(exec), nil } // NewSpannerClient creates a new SpannerClient from the config yaml string and an optional database override. @@ -124,14 +123,10 @@ func NewSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride s return nil, err } - defaultClient := newDefaultSpannerClient(exec) - normalizedClient, err := NewNormalizedClient(defaultClient) - if err != nil { - slog.Error("Failed to create normalized client in NewSpannerClient", "error", err) - return nil, err - } + defaultClient := newStandardSpannerClient(exec) + normalizedSchemaClient := NewNormalizedClient(defaultClient) - return NewSchemaSelectorClient(defaultClient, normalizedClient) + return normalizedSchemaClient, nil } // createSpannerClient creates the database name string and initializes the Spanner client. @@ -166,21 +161,21 @@ func createSpannerConfig(spannerConfigYaml, databaseOverride string) (*SpannerCo return &cfg, nil } -func (sc *defaultSpannerClient) Id() string { +func (sc *standardSpannerClient) Id() string { return sc.exec.Id() } // Start starts the background goroutine to periodically fetch the timestamp. -func (sc *defaultSpannerClient) Start() { +func (sc *standardSpannerClient) Start() { sc.exec.Start() } // Close closes the Spanner client and stops the background goroutine. -func (sc *defaultSpannerClient) Close() { +func (sc *standardSpannerClient) Close() { sc.exec.Close() } // GetSdmxObservations is not supported on the default client. -func (sc *defaultSpannerClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { +func (sc *standardSpannerClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { return nil, status.Error(codes.Unimplemented, "SDMX queries are only supported on the normalized schema") } diff --git a/internal/server/spanner/client_selector.go b/internal/server/spanner/client_selector.go deleted file mode 100644 index aa269bd45..000000000 --- a/internal/server/spanner/client_selector.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package spanner - -import ( - "context" - "log/slog" - - pb "github.com/datacommonsorg/mixer/internal/proto" - v2 "github.com/datacommonsorg/mixer/internal/server/v2" - "github.com/datacommonsorg/mixer/internal/util" - "google.golang.org/grpc/metadata" -) - -// schemaSelectorClient dispatches calls to either default or normalized client. -type schemaSelectorClient struct { - SpannerClient // Embeds the default client - normalized NormalizedObservationProvider -} - -// GetObservations overrides the embedded client's GetObservations to dispatch based on schema selection. -func (s *schemaSelectorClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { - if useNormalizedSchema(ctx) { - logNormalizedInvocation("GetObservations", - "num_variables", len(variables), - "num_entities", len(entities), - ) - return s.normalized.GetObservations(ctx, variables, entities) - } - return s.SpannerClient.GetObservations(ctx, variables, entities) -} - -// CheckVariableExistence overrides the embedded client's CheckVariableExistence to dispatch based on schema selection. -func (s *schemaSelectorClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { - if useNormalizedSchema(ctx) { - logNormalizedInvocation("CheckVariableExistence", - "num_variables", len(variables), - "num_entities", len(entities), - ) - return s.normalized.CheckVariableExistence(ctx, variables, entities) - } - return s.SpannerClient.CheckVariableExistence(ctx, variables, entities) -} - -// GetObservationsContainedInPlace overrides the embedded client's GetObservationsContainedInPlace to dispatch based on schema selection. -func (s *schemaSelectorClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { - if useNormalizedSchema(ctx) { - logNormalizedInvocation("GetObservationsContainedInPlace", - "num_variables", len(variables), - "ancestor", containedInPlace.Ancestor, - "child_place_type", containedInPlace.ChildPlaceType, - ) - return s.normalized.GetObservationsContainedInPlace(ctx, variables, containedInPlace) - } - return s.SpannerClient.GetObservationsContainedInPlace(ctx, variables, containedInPlace) -} - -// GetSdmxObservations overrides the embedded client's GetSdmxObservations. -// SDMX is only supported on the normalized schema, so it always delegates to the normalized client. -func (s *schemaSelectorClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { - logNormalizedInvocation("GetSdmxObservations", - "query", req, - ) - return s.normalized.GetSdmxObservations(ctx, req) -} - -// NewSchemaSelectorClient creates a new SpannerClient that dispatches calls to either default or normalized client. -func NewSchemaSelectorClient(baseClient SpannerClient, normalizedClient NormalizedObservationProvider) (SpannerClient, error) { - return &schemaSelectorClient{ - SpannerClient: baseClient, - normalized: normalizedClient, - }, nil -} - -// useNormalizedSchema checks whether to use the normalized Spanner schema based on request header. -func useNormalizedSchema(ctx context.Context) bool { - if md, ok := metadata.FromIncomingContext(ctx); ok { - headers := md.Get(util.XUseNormalizedSchema) - return len(headers) > 0 && headers[0] == "true" - } - return false -} - -// shouldLogSQL checks whether to log the full interpolated SQL query based on request header. -func shouldLogSQL(ctx context.Context) bool { - if md, ok := metadata.FromIncomingContext(ctx); ok { - headers := md.Get(util.XLogSQL) - return len(headers) > 0 && headers[0] == "true" - } - return false -} - -// getSchemaName returns the name of the schema being used based on context. -func getSchemaName(ctx context.Context) string { - if useNormalizedSchema(ctx) { - return "Normalized" - } - return "Legacy" -} - -// logNormalizedInvocation logs that the normalized schema was invoked for a method with custom arguments. -func logNormalizedInvocation(methodName string, args ...any) { - fullArgs := append([]any{"method", methodName}, args...) - slog.Info("Invoking normalized Spanner schema", fullArgs...) -} diff --git a/internal/server/spanner/executor.go b/internal/server/spanner/executor.go index dba92580e..28b8cca1a 100644 --- a/internal/server/spanner/executor.go +++ b/internal/server/spanner/executor.go @@ -27,7 +27,9 @@ import ( "cloud.google.com/go/spanner" "github.com/datacommonsorg/mixer/internal/metrics" + "github.com/datacommonsorg/mixer/internal/util" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/api/iterator" @@ -350,3 +352,20 @@ func processCacheRows[T proto.Message](iter *spanner.RowIterator, newProto func( func isTimeoutError(err error) bool { return spanner.ErrCode(err) == codes.DeadlineExceeded || errors.Is(err, context.DeadlineExceeded) } + +// shouldLogSQL checks whether to log the full interpolated SQL query based on request header. +func shouldLogSQL(ctx context.Context) bool { + if md, ok := metadata.FromIncomingContext(ctx); ok { + headers := md.Get(util.XLogSQL) + return len(headers) > 0 && headers[0] == "true" + } + return false +} + +// getSchemaName returns the name of the schema being used based on context. +func getSchemaName(ctx context.Context) string { + if useNormalizedSchema(ctx) { + return "Normalized" + } + return "Legacy" +} diff --git a/internal/server/spanner/golden/query_normalized_test.go b/internal/server/spanner/golden/query_normalized_test.go index 584e0cd91..1f72e5b2e 100644 --- a/internal/server/spanner/golden/query_normalized_test.go +++ b/internal/server/spanner/golden/query_normalized_test.go @@ -28,10 +28,7 @@ func TestNormalizedGetObservations(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range normalizedObservationsTestCases { goldenFile := c.golden + ".json" @@ -51,10 +48,7 @@ func TestNormalizedCheckVariableExistence(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range checkVariableExistenceTestCases { goldenFile := c.golden + ".json" @@ -85,10 +79,7 @@ func TestNormalizedGetObservationsContainedInPlace(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range getObservationsContainedInPlaceTestCases { goldenFile := c.golden + ".json" @@ -111,10 +102,7 @@ func TestNormalizedGetSdmxObservations(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range sdmxObservationsTestCases { goldenFile := c.golden + ".json" diff --git a/internal/server/spanner/query.go b/internal/server/spanner/query.go index d3d1b9653..95f44b48d 100644 --- a/internal/server/spanner/query.go +++ b/internal/server/spanner/query.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Queries executed by the SpannerClient. +// This file implements the methods for standardClient (Default/Legacy Schema). package spanner import ( @@ -73,7 +73,7 @@ const ( ) // GetNodeProps retrieves node properties from Spanner given a list of IDs and a direction and returns a map. -func (sc *defaultSpannerClient) GetNodeProps(ctx context.Context, ids []string, out bool) (map[string][]*Property, error) { +func (sc *standardSpannerClient) GetNodeProps(ctx context.Context, ids []string, out bool) (map[string][]*Property, error) { props := map[string][]*Property{} if len(ids) == 0 { return props, nil @@ -102,7 +102,7 @@ func (sc *defaultSpannerClient) GetNodeProps(ctx context.Context, ids []string, } // GetNodeEdgesByID retrieves node edges from Spanner and returns a map of subjectID to Edges. -func (sc *defaultSpannerClient) GetNodeEdgesByID(ctx context.Context, ids []string, arc *v2.Arc, pageSize, offset int) (map[string][]*Edge, error) { +func (sc *standardSpannerClient) GetNodeEdgesByID(ctx context.Context, ids []string, arc *v2.Arc, pageSize, offset int) (map[string][]*Edge, error) { edges := make(map[string][]*Edge) if len(ids) == 0 { return edges, nil @@ -131,7 +131,7 @@ func (sc *defaultSpannerClient) GetNodeEdgesByID(ctx context.Context, ids []stri } // GetObservations retrieves observations from Spanner given a list of variables and entities. -func (sc *defaultSpannerClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { +func (sc *standardSpannerClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { var observations []*Observation if len(entities) == 0 { return nil, fmt.Errorf("entity must be specified") @@ -157,7 +157,7 @@ func (sc *defaultSpannerClient) GetObservations(ctx context.Context, variables [ // CheckVariableExistence checks for the existence of observations for the given variables and entities. // Returns a slice of rows, where each row contains [variable, entity] that has at least one observation. -func (sc *defaultSpannerClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { +func (sc *standardSpannerClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { stmt, err := FilterStatVarsByEntityQuery(variables, entities) if err != nil { return nil, err @@ -166,7 +166,7 @@ func (sc *defaultSpannerClient) CheckVariableExistence(ctx context.Context, vari } // GetObservationsContainedInPlace retrieves observations from Spanner given a list of variables and an entity expression. -func (sc *defaultSpannerClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { +func (sc *standardSpannerClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { var observations []*Observation if len(variables) == 0 || containedInPlace == nil { return observations, nil @@ -193,7 +193,7 @@ func (sc *defaultSpannerClient) GetObservationsContainedInPlace(ctx context.Cont // SearchNodes searches nodes in the graph based on the query and optionally the types. // If the types array is empty, it searches across nodes of all types. // A maximum of 100 results are returned. -func (sc *defaultSpannerClient) SearchNodes(ctx context.Context, query string, types []string) ([]*SearchNode, error) { +func (sc *standardSpannerClient) SearchNodes(ctx context.Context, query string, types []string) ([]*SearchNode, error) { var nodes []*SearchNode if query == "" { return nodes, nil @@ -218,7 +218,7 @@ func (sc *defaultSpannerClient) SearchNodes(ctx context.Context, query string, t } // ResolveByID fetches ID resolution candidates for a list of input nodes and in and out properties and returns a map of node to candidates. -func (sc *defaultSpannerClient) ResolveByID(ctx context.Context, nodes []string, in, out string) (map[string][]string, error) { +func (sc *standardSpannerClient) ResolveByID(ctx context.Context, nodes []string, in, out string) (map[string][]string, error) { nodeToCandidates := make(map[string][]string) if len(nodes) == 0 { return nodeToCandidates, nil @@ -252,7 +252,7 @@ func (sc *defaultSpannerClient) ResolveByID(ctx context.Context, nodes []string, } // GetEventCollectionDate retrieves event collection dates from Spanner. -func (sc *defaultSpannerClient) GetEventCollectionDate(ctx context.Context, placeID, eventType string) ([]string, error) { +func (sc *standardSpannerClient) GetEventCollectionDate(ctx context.Context, placeID, eventType string) ([]string, error) { stmt := GetEventCollectionDateQuery(placeID, eventType) rows, err := sc.exec.queryDynamic(ctx, *stmt) if err != nil { @@ -269,7 +269,7 @@ func (sc *defaultSpannerClient) GetEventCollectionDate(ctx context.Context, plac } // GetEventCollection retrieves and filters event collection from Spanner. -func (sc *defaultSpannerClient) GetEventCollection(ctx context.Context, req *pbv1.EventCollectionRequest) (*pbv1.EventCollection, error) { +func (sc *standardSpannerClient) GetEventCollection(ctx context.Context, req *pbv1.EventCollectionRequest) (*pbv1.EventCollection, error) { // Get event DCIDs eventRows, err := sc.GetEventCollectionDcids(ctx, req.AffectedPlaceDcid, req.EventType, req.Date) if err != nil { @@ -301,7 +301,7 @@ func (sc *defaultSpannerClient) GetEventCollection(ctx context.Context, req *pbv return res, nil } -func (sc *defaultSpannerClient) populateProvenanceInfo(ctx context.Context, res *pbv1.EventCollection) error { +func (sc *standardSpannerClient) populateProvenanceInfo(ctx context.Context, res *pbv1.EventCollection) error { provDcids := []string{} seen := map[string]bool{} for _, event := range res.Events { @@ -506,7 +506,7 @@ func keepEvent(event *pbv1.EventCollection_Event, req *pbv1.EventCollectionReque } // GetEventCollectionDcids retrieves event DCIDs from Spanner. -func (sc *defaultSpannerClient) GetEventCollectionDcids(ctx context.Context, placeID, eventType, date string) ([]EventIdWithMagnitudeDcid, error) { +func (sc *standardSpannerClient) GetEventCollectionDcids(ctx context.Context, placeID, eventType, date string) ([]EventIdWithMagnitudeDcid, error) { stmt := GetEventCollectionDcidsQuery(placeID, eventType, date) rows, err := sc.exec.queryDynamic(ctx, *stmt) if err != nil { @@ -585,7 +585,7 @@ func parseAndSortEvents(rows []EventIdWithMagnitudeDcid, eventType string) []str return res } -func (sc *defaultSpannerClient) Sparql(ctx context.Context, nodes []types.Node, queries []*types.Query, opts *types.QueryOptions) ([][]string, error) { +func (sc *standardSpannerClient) Sparql(ctx context.Context, nodes []types.Node, queries []*types.Query, opts *types.QueryOptions) ([][]string, error) { query, err := SparqlQuery(nodes, queries, opts) if err != nil { return nil, fmt.Errorf("error building sparql query: %v", err) @@ -594,7 +594,7 @@ func (sc *defaultSpannerClient) Sparql(ctx context.Context, nodes []types.Node, return sc.exec.queryDynamic(ctx, *query) } -func (sc *defaultSpannerClient) GetProvenanceSummary(ctx context.Context, variables []string) (map[string]map[string]*pb.StatVarSummary_ProvenanceSummary, error) { +func (sc *standardSpannerClient) GetProvenanceSummary(ctx context.Context, variables []string) (map[string]map[string]*pb.StatVarSummary_ProvenanceSummary, error) { if len(variables) == 0 { return map[string]map[string]*pb.StatVarSummary_ProvenanceSummary{}, nil @@ -616,7 +616,7 @@ func (sc *defaultSpannerClient) GetProvenanceSummary(ctx context.Context, variab } // GetTermEmbeddingQuery retrieves embeddings from Spanner for a given query. -func (sc *defaultSpannerClient) GetTermEmbeddingQuery(ctx context.Context, modelName, searchLabel, taskType string) ([]float64, error) { +func (sc *standardSpannerClient) GetTermEmbeddingQuery(ctx context.Context, modelName, searchLabel, taskType string) ([]float64, error) { embeddings := []float64{} err := sc.exec.executeQuery(ctx, *GetTermEmbeddingQuery(modelName, searchLabel, taskType), func(iter *spanner.RowIterator) error { row, err := iter.Next() @@ -632,7 +632,7 @@ func (sc *defaultSpannerClient) GetTermEmbeddingQuery(ctx context.Context, model } // FilterNodesByTypes filters a list of nodes by types and returns a map of node to matched types. -func (sc *defaultSpannerClient) FilterNodesByTypes(ctx context.Context, nodes []string, typeFilters []string) (map[string][]string, error) { +func (sc *standardSpannerClient) FilterNodesByTypes(ctx context.Context, nodes []string, typeFilters []string) (map[string][]string, error) { if len(nodes) == 0 { return map[string][]string{}, nil } @@ -662,7 +662,7 @@ func (sc *defaultSpannerClient) FilterNodesByTypes(ctx context.Context, nodes [] } // VectorSearchQuery performs vector similarity search in Spanner. -func (sc *defaultSpannerClient) VectorSearchQuery(ctx context.Context, tableName string, limit int, embeddings []float64, numLeaves int, threshold float64, nodeTypes []string) ([]*VectorSearchResult, error) { +func (sc *standardSpannerClient) VectorSearchQuery(ctx context.Context, tableName string, limit int, embeddings []float64, numLeaves int, threshold float64, nodeTypes []string) ([]*VectorSearchResult, error) { var results []*VectorSearchResult err := sc.exec.queryStructs( ctx, @@ -679,7 +679,7 @@ func (sc *defaultSpannerClient) VectorSearchQuery(ctx context.Context, tableName } // GetStatVarGroupNode fetches StatVarGroupNode info from Spanner. -func (sc *defaultSpannerClient) GetStatVarGroupNode(ctx context.Context, nodes []string, includeDefinitions bool) ([]*StatVarGroupNode, error) { +func (sc *standardSpannerClient) GetStatVarGroupNode(ctx context.Context, nodes []string, includeDefinitions bool) ([]*StatVarGroupNode, error) { var svgNodes []*StatVarGroupNode if len(nodes) == 0 { return svgNodes, nil @@ -703,7 +703,7 @@ func (sc *defaultSpannerClient) GetStatVarGroupNode(ctx context.Context, nodes [ } // GetFilteredStatVarGroupNode fetches filtered StatVarGroupNode info from Spanner. -func (sc *defaultSpannerClient) GetFilteredStatVarGroupNode(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (map[string]*FilteredStatVarGroupNode, error) { +func (sc *standardSpannerClient) GetFilteredStatVarGroupNode(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (map[string]*FilteredStatVarGroupNode, error) { response := map[string]*FilteredStatVarGroupNode{} errGroup, errCtx := errgroup.WithContext(ctx) errGroup.SetLimit(maxConcurrentFilteredSVGGoroutines) // Limit the number of concurrent goroutines to avoid overwhelming Spanner with too many requests. @@ -739,7 +739,7 @@ func (sc *defaultSpannerClient) GetFilteredStatVarGroupNode(ctx context.Context, } // getSingleFilteredStatVarGroupNode fetches the relevant info to build a single filtered StatVarGroupNode from Spanner. -func (sc *defaultSpannerClient) getSingleFilteredStatVarGroupNode(ctx context.Context, node string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (*FilteredStatVarGroupNode, error) { +func (sc *standardSpannerClient) getSingleFilteredStatVarGroupNode(ctx context.Context, node string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (*FilteredStatVarGroupNode, error) { filteredStatVarGroupNode := &FilteredStatVarGroupNode{} errGroup, errCtx := errgroup.WithContext(ctx) svgChildChan := make(chan []*SVGChild, 1) @@ -819,7 +819,7 @@ func (sc *defaultSpannerClient) getSingleFilteredStatVarGroupNode(ctx context.Co } // GetFilteredTopic fetches the relevant info to build a filtered Topic response from Spanner. -func (sc *defaultSpannerClient) GetFilteredTopic(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int) (map[string]int, error) { +func (sc *standardSpannerClient) GetFilteredTopic(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int) (map[string]int, error) { counts := make(map[string]int, len(nodes)) for _, node := range nodes { counts[node] = 0 diff --git a/internal/server/spanner/query_normalized.go b/internal/server/spanner/query_normalized.go index 1f4ce6b24..c2b771357 100644 --- a/internal/server/spanner/query_normalized.go +++ b/internal/server/spanner/query_normalized.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This file implements the methods for normalizedClient (Normalized Schema). package spanner import ( @@ -21,15 +22,29 @@ import ( pb "github.com/datacommonsorg/mixer/internal/proto" v2 "github.com/datacommonsorg/mixer/internal/server/v2" + "github.com/datacommonsorg/mixer/internal/util" + "google.golang.org/grpc/metadata" ) +func useNormalizedSchema(ctx context.Context) bool { + if md, ok := metadata.FromIncomingContext(ctx); ok { + headers := md.Get(util.XUseNormalizedSchema) + return len(headers) > 0 && headers[0] == "true" + } + return false +} + // GetObservations retrieves observations from Spanner given a list of variables and entities // using the normalized schema. -func (nc *normalizedClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { +func (nc *normalizedSchemaClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { if len(entities) == 0 { return nil, fmt.Errorf("entity must be specified") } + if !useNormalizedSchema(ctx) { + return nc.SpannerClient.GetObservations(ctx, variables, entities) + } + rawObs, err := nc.fetchRawObservations(ctx, variables, entities) if err != nil { return nil, err @@ -43,7 +58,7 @@ func (nc *normalizedClient) GetObservations(ctx context.Context, variables []str } // fetchRawObservations fetches data from TimeSeries and StatVarObservation tables. -func (nc *normalizedClient) fetchRawObservations(ctx context.Context, variables []string, entities []string) ([]*rawObservation, error) { +func (nc *normalizedSchemaClient) fetchRawObservations(ctx context.Context, variables []string, entities []string) ([]*rawObservation, error) { stmt := GetNormalizedObservationsQuery(variables, entities) var rawObs []*rawObservation @@ -90,7 +105,11 @@ func reconstructObservations(rawObs []*rawObservation) []*Observation { } // CheckVariableExistence checks which variables exist for which entities using the normalized schema. -func (nc *normalizedClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { +func (nc *normalizedSchemaClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { + if !useNormalizedSchema(ctx) { + return nc.SpannerClient.CheckVariableExistence(ctx, variables, entities) + } + stmt, err := GetNormalizedStatVarsByEntityQuery(variables, entities) if err != nil { return nil, err @@ -100,11 +119,15 @@ func (nc *normalizedClient) CheckVariableExistence(ctx context.Context, variable // GetObservationsContainedInPlace retrieves observations for entities contained in a place // using the normalized schema. -func (nc *normalizedClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { +func (nc *normalizedSchemaClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { if containedInPlace == nil { return nil, fmt.Errorf("containedInPlace must be specified") } + if !useNormalizedSchema(ctx) { + return nc.SpannerClient.GetObservationsContainedInPlace(ctx, variables, containedInPlace) + } + rawObs, err := nc.fetchRawObservationsContainedInPlace(ctx, variables, containedInPlace) if err != nil { return nil, err @@ -118,7 +141,7 @@ func (nc *normalizedClient) GetObservationsContainedInPlace(ctx context.Context, } // fetchRawObservationsContainedInPlace fetches data from Graph, TimeSeries and StatVarObservation tables. -func (nc *normalizedClient) fetchRawObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*rawObservation, error) { +func (nc *normalizedSchemaClient) fetchRawObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*rawObservation, error) { stmt := GetNormalizedObservationsContainedInPlaceQuery(variables, containedInPlace) var rawObs []*rawObservation @@ -139,7 +162,7 @@ var facetAttributes = map[string]bool{ // GetSdmxObservations retrieves observations from Spanner given a list of constraints // using the normalized schema and relational division. -func (nc *normalizedClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { +func (nc *normalizedSchemaClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { stmt := GetSdmxObservationsQuery(req) var rawObs []*rawObservation From fe059219baf5d90b8d52df92414956f150ca6f3a Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Thu, 14 May 2026 22:23:10 -0700 Subject: [PATCH 3/8] no panic --- internal/server/spanner/client.go | 13 ++++++++---- .../spanner/golden/query_normalized_test.go | 20 +++++++++++++++---- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/internal/server/spanner/client.go b/internal/server/spanner/client.go index 7617dfa2a..209a0e4d2 100644 --- a/internal/server/spanner/client.go +++ b/internal/server/spanner/client.go @@ -74,15 +74,17 @@ type normalizedSchemaClient struct { } // NewNormalizedClient creates a new normalizedSchemaClient. -func NewNormalizedClient(client SpannerClient) *normalizedSchemaClient { +func NewNormalizedClient(client SpannerClient) (*normalizedSchemaClient, error) { sc, ok := client.(*standardSpannerClient) if !ok { - panic(fmt.Sprintf("NewNormalizedClient: expected *standardSpannerClient, got %T", client)) + err := fmt.Errorf("NewNormalizedClient: expected *standardSpannerClient, got %T", client) + slog.Error("Failed to create normalized client", "error", err) + return nil, err } return &normalizedSchemaClient{ SpannerClient: client, exec: sc.exec, - } + }, nil } // Force compiler that all methods required by the interface are implemented by clients @@ -124,7 +126,10 @@ func NewSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride s } defaultClient := newStandardSpannerClient(exec) - normalizedSchemaClient := NewNormalizedClient(defaultClient) + normalizedSchemaClient, err := NewNormalizedClient(defaultClient) + if err != nil { + return nil, err + } return normalizedSchemaClient, nil } diff --git a/internal/server/spanner/golden/query_normalized_test.go b/internal/server/spanner/golden/query_normalized_test.go index 1f72e5b2e..584e0cd91 100644 --- a/internal/server/spanner/golden/query_normalized_test.go +++ b/internal/server/spanner/golden/query_normalized_test.go @@ -28,7 +28,10 @@ func TestNormalizedGetObservations(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc := spanner.NewNormalizedClient(client) + nc, err := spanner.NewNormalizedClient(client) + if err != nil { + t.Fatalf("NewNormalizedClient failed: %v", err) + } for _, c := range normalizedObservationsTestCases { goldenFile := c.golden + ".json" @@ -48,7 +51,10 @@ func TestNormalizedCheckVariableExistence(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc := spanner.NewNormalizedClient(client) + nc, err := spanner.NewNormalizedClient(client) + if err != nil { + t.Fatalf("NewNormalizedClient failed: %v", err) + } for _, c := range checkVariableExistenceTestCases { goldenFile := c.golden + ".json" @@ -79,7 +85,10 @@ func TestNormalizedGetObservationsContainedInPlace(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc := spanner.NewNormalizedClient(client) + nc, err := spanner.NewNormalizedClient(client) + if err != nil { + t.Fatalf("NewNormalizedClient failed: %v", err) + } for _, c := range getObservationsContainedInPlaceTestCases { goldenFile := c.golden + ".json" @@ -102,7 +111,10 @@ func TestNormalizedGetSdmxObservations(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc := spanner.NewNormalizedClient(client) + nc, err := spanner.NewNormalizedClient(client) + if err != nil { + t.Fatalf("NewNormalizedClient failed: %v", err) + } for _, c := range sdmxObservationsTestCases { goldenFile := c.golden + ".json" From 99d20f673b7685e3d5dd6695324947236d376b50 Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Thu, 14 May 2026 22:27:21 -0700 Subject: [PATCH 4/8] add back logging --- internal/server/spanner/query_normalized.go | 27 +++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/internal/server/spanner/query_normalized.go b/internal/server/spanner/query_normalized.go index c2b771357..b2589576c 100644 --- a/internal/server/spanner/query_normalized.go +++ b/internal/server/spanner/query_normalized.go @@ -18,6 +18,7 @@ package spanner import ( "context" "fmt" + "log/slog" "strconv" pb "github.com/datacommonsorg/mixer/internal/proto" @@ -34,6 +35,12 @@ func useNormalizedSchema(ctx context.Context) bool { return false } +// logNormalizedInvocation logs that the normalized schema was invoked for a method with custom arguments. +func logNormalizedInvocation(methodName string, args ...any) { + fullArgs := append([]any{"method", methodName}, args...) + slog.Info("Invoking normalized Spanner schema", fullArgs...) +} + // GetObservations retrieves observations from Spanner given a list of variables and entities // using the normalized schema. func (nc *normalizedSchemaClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { @@ -45,6 +52,11 @@ func (nc *normalizedSchemaClient) GetObservations(ctx context.Context, variables return nc.SpannerClient.GetObservations(ctx, variables, entities) } + logNormalizedInvocation("GetObservations", + "num_variables", len(variables), + "num_entities", len(entities), + ) + rawObs, err := nc.fetchRawObservations(ctx, variables, entities) if err != nil { return nil, err @@ -110,6 +122,11 @@ func (nc *normalizedSchemaClient) CheckVariableExistence(ctx context.Context, va return nc.SpannerClient.CheckVariableExistence(ctx, variables, entities) } + logNormalizedInvocation("CheckVariableExistence", + "num_variables", len(variables), + "num_entities", len(entities), + ) + stmt, err := GetNormalizedStatVarsByEntityQuery(variables, entities) if err != nil { return nil, err @@ -128,6 +145,12 @@ func (nc *normalizedSchemaClient) GetObservationsContainedInPlace(ctx context.Co return nc.SpannerClient.GetObservationsContainedInPlace(ctx, variables, containedInPlace) } + logNormalizedInvocation("GetObservationsContainedInPlace", + "num_variables", len(variables), + "ancestor", containedInPlace.Ancestor, + "child_place_type", containedInPlace.ChildPlaceType, + ) + rawObs, err := nc.fetchRawObservationsContainedInPlace(ctx, variables, containedInPlace) if err != nil { return nil, err @@ -163,6 +186,10 @@ var facetAttributes = map[string]bool{ // GetSdmxObservations retrieves observations from Spanner given a list of constraints // using the normalized schema and relational division. func (nc *normalizedSchemaClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { + logNormalizedInvocation("GetSdmxObservations", + "query", req, + ) + stmt := GetSdmxObservationsQuery(req) var rawObs []*rawObservation From 588754619222f87112a232169f032098f44c9175 Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Fri, 15 May 2026 11:34:32 -0700 Subject: [PATCH 5/8] selector and embed --- internal/server/spanner/client.go | 126 +++++++++++------- .../spanner/{executor.go => connector.go} | 28 ++-- .../spanner/golden/query_normalized_test.go | 20 +-- ...ery_normalized.go => normalized_client.go} | 85 +++++------- .../spanner/{query.go => standard_client.go} | 60 +++++---- internal/server/spanner/timestamp_test.go | 4 +- 6 files changed, 170 insertions(+), 153 deletions(-) rename internal/server/spanner/{executor.go => connector.go} (93%) rename internal/server/spanner/{query_normalized.go => normalized_client.go} (78%) rename internal/server/spanner/{query.go => standard_client.go} (95%) diff --git a/internal/server/spanner/client.go b/internal/server/spanner/client.go index 209a0e4d2..b8620ee31 100644 --- a/internal/server/spanner/client.go +++ b/internal/server/spanner/client.go @@ -25,8 +25,8 @@ import ( pbv1 "github.com/datacommonsorg/mixer/internal/proto/v1" v2 "github.com/datacommonsorg/mixer/internal/server/v2" "github.com/datacommonsorg/mixer/internal/translator/types" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/datacommonsorg/mixer/internal/util" + "google.golang.org/grpc/metadata" "gopkg.in/yaml.v3" ) @@ -56,40 +56,85 @@ type SpannerClient interface { Close() } -// standardSpannerClient encapsulates the Spanner client that directly interacts with the Spanner database. -type standardSpannerClient struct { - exec *SpannerExecutor + + +func useNormalizedSchema(ctx context.Context) bool { + if md, ok := metadata.FromIncomingContext(ctx); ok { + headers := md.Get(util.XUseNormalizedSchema) + return len(headers) > 0 && headers[0] == "true" + } + return false } -// newStandardSpannerClient creates a new standardSpannerClient. -func newStandardSpannerClient(exec *SpannerExecutor) *standardSpannerClient { - return &standardSpannerClient{exec: exec} +// selectorClient dispatches calls to either default or normalized client based on request headers. +// It serves as the main entry point for the Spanner client, centralizing routing concerns. +// +// DESIGN NOTE: This client embeds the standard client (SpannerClient) to handle automatic +// fallback for methods that do not have a specialized normalized implementation. +// For methods that DO have a specialized implementation (like GetObservations), it explicitly +// checks the header and routes accordingly. It does NOT rely on the normalized client's +// internal fallback for general request routing, ensuring that the standard path remains +// the explicit default. +type selectorClient struct { + SpannerClient // Embeds the default client + normalized *normalizedSchemaClient } -// normalizedSchemaClient encapsulates the Spanner client for the normalized schema. -// It embeds SpannerClient to inherit default behavior and only overrides specific methods. -type normalizedSchemaClient struct { - SpannerClient // Embeds standardSpannerClient - exec *SpannerExecutor +// logNormalizedInvocation logs that the normalized schema was invoked for a method with custom arguments. +func logNormalizedInvocation(methodName string, args ...any) { + fullArgs := append([]any{"method", methodName}, args...) + slog.Info("Invoking normalized Spanner schema", fullArgs...) } -// NewNormalizedClient creates a new normalizedSchemaClient. -func NewNormalizedClient(client SpannerClient) (*normalizedSchemaClient, error) { - sc, ok := client.(*standardSpannerClient) - if !ok { - err := fmt.Errorf("NewNormalizedClient: expected *standardSpannerClient, got %T", client) - slog.Error("Failed to create normalized client", "error", err) - return nil, err +// GetObservations overrides the embedded client's GetObservations to dispatch based on schema selection. +func (s *selectorClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { + if useNormalizedSchema(ctx) { + logNormalizedInvocation("GetObservations", + "num_variables", len(variables), + "num_entities", len(entities), + ) + return s.normalized.GetObservations(ctx, variables, entities) } - return &normalizedSchemaClient{ - SpannerClient: client, - exec: sc.exec, - }, nil + return s.SpannerClient.GetObservations(ctx, variables, entities) +} + +// CheckVariableExistence overrides the embedded client's CheckVariableExistence to dispatch based on schema selection. +func (s *selectorClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { + if useNormalizedSchema(ctx) { + logNormalizedInvocation("CheckVariableExistence", + "num_variables", len(variables), + "num_entities", len(entities), + ) + return s.normalized.CheckVariableExistence(ctx, variables, entities) + } + return s.SpannerClient.CheckVariableExistence(ctx, variables, entities) +} + +// GetObservationsContainedInPlace overrides the embedded client's GetObservationsContainedInPlace to dispatch based on schema selection. +func (s *selectorClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { + if useNormalizedSchema(ctx) { + logNormalizedInvocation("GetObservationsContainedInPlace", + "num_variables", len(variables), + "ancestor", containedInPlace.Ancestor, + "child_place_type", containedInPlace.ChildPlaceType, + ) + return s.normalized.GetObservationsContainedInPlace(ctx, variables, containedInPlace) + } + return s.SpannerClient.GetObservationsContainedInPlace(ctx, variables, containedInPlace) +} + +// GetSdmxObservations overrides the embedded client's GetSdmxObservations. +// SDMX is only supported on the normalized schema, so it always delegates to the normalized client. +func (s *selectorClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { + logNormalizedInvocation("GetSdmxObservations", + "query", req, + ) + return s.normalized.GetSdmxObservations(ctx, req) } // Force compiler that all methods required by the interface are implemented by clients var _ SpannerClient = (*standardSpannerClient)(nil) -var _ SpannerClient = (*normalizedSchemaClient)(nil) +var _ SpannerClient = (*selectorClient)(nil) // NewRawSpannerClient creates a new SpannerClient without the schema selector. // This is intended for testing and internal use where a direct client is needed. @@ -102,7 +147,7 @@ func NewRawSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverrid if err != nil { return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) } - exec, err := NewSpannerExecutor(client) + exec, err := NewSpannerConnector(client) if err != nil { return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) } @@ -120,18 +165,18 @@ func NewSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride s if err != nil { return nil, err } - exec, err := NewSpannerExecutor(client) + exec, err := NewSpannerConnector(client) if err != nil { return nil, err } defaultClient := newStandardSpannerClient(exec) - normalizedSchemaClient, err := NewNormalizedClient(defaultClient) - if err != nil { - return nil, err - } + normalizedClient := NewNormalizedClient(defaultClient) - return normalizedSchemaClient, nil + return &selectorClient{ + SpannerClient: defaultClient, + normalized: normalizedClient, + }, nil } // createSpannerClient creates the database name string and initializes the Spanner client. @@ -166,21 +211,4 @@ func createSpannerConfig(spannerConfigYaml, databaseOverride string) (*SpannerCo return &cfg, nil } -func (sc *standardSpannerClient) Id() string { - return sc.exec.Id() -} - -// Start starts the background goroutine to periodically fetch the timestamp. -func (sc *standardSpannerClient) Start() { - sc.exec.Start() -} -// Close closes the Spanner client and stops the background goroutine. -func (sc *standardSpannerClient) Close() { - sc.exec.Close() -} - -// GetSdmxObservations is not supported on the default client. -func (sc *standardSpannerClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { - return nil, status.Error(codes.Unimplemented, "SDMX queries are only supported on the normalized schema") -} diff --git a/internal/server/spanner/executor.go b/internal/server/spanner/connector.go similarity index 93% rename from internal/server/spanner/executor.go rename to internal/server/spanner/connector.go index 28b8cca1a..56274e159 100644 --- a/internal/server/spanner/executor.go +++ b/internal/server/spanner/connector.go @@ -35,8 +35,8 @@ import ( "google.golang.org/api/iterator" ) -// SpannerExecutor handles the low-level details of connecting to Spanner and executing queries. -type SpannerExecutor struct { +// SpannerConnector handles the low-level details of connecting to Spanner and executing queries. +type SpannerConnector struct { client *spanner.Client timestamp atomic.Int64 ticker Ticker @@ -47,9 +47,9 @@ type SpannerExecutor struct { updateTimestamp func(context.Context) error } -// NewSpannerExecutor creates a new SpannerExecutor. -func NewSpannerExecutor(client *spanner.Client) (*SpannerExecutor, error) { - se := &SpannerExecutor{ +// NewSpannerConnector creates a new SpannerConnector. +func NewSpannerConnector(client *spanner.Client) (*SpannerConnector, error) { + se := &SpannerConnector{ client: client, } @@ -65,12 +65,12 @@ func NewSpannerExecutor(client *spanner.Client) (*SpannerExecutor, error) { } // Id returns the database name. -func (se *SpannerExecutor) Id() string { +func (se *SpannerConnector) Id() string { return se.client.DatabaseName() } // Start starts the background goroutine to periodically fetch the timestamp. -func (se *SpannerExecutor) Start() { +func (se *SpannerConnector) Start() { se.startOnce.Do(func() { ctx, cancel := context.WithCancel(context.Background()) @@ -96,7 +96,7 @@ func (se *SpannerExecutor) Start() { } // Close closes the Spanner client and stops the background goroutine. -func (se *SpannerExecutor) Close() { +func (se *SpannerConnector) Close() { se.stopOnce.Do(func() { close(se.stopCh) se.wg.Wait() @@ -107,7 +107,7 @@ func (se *SpannerExecutor) Close() { } // fetchAndUpdateTimestamp queries Spanner and updates the timestamp. -func (se *SpannerExecutor) fetchAndUpdateTimestamp(ctx context.Context) error { +func (se *SpannerConnector) fetchAndUpdateTimestamp(ctx context.Context) error { queryCtx, cancel := context.WithTimeout(ctx, timestampPollingTimeout) defer cancel() @@ -148,7 +148,7 @@ func (se *SpannerExecutor) fetchAndUpdateTimestamp(ctx context.Context) error { return nil } -func (se *SpannerExecutor) getStalenessTimestamp() (time.Time, error) { +func (se *SpannerConnector) getStalenessTimestamp() (time.Time, error) { val := se.timestamp.Load() if val != 0 { return time.Unix(0, val).UTC(), nil @@ -157,7 +157,7 @@ func (se *SpannerExecutor) getStalenessTimestamp() (time.Time, error) { return time.Time{}, fmt.Errorf("error getting staleness timestamp") } -func (se *SpannerExecutor) executeQuery( +func (se *SpannerConnector) executeQuery( ctx context.Context, stmt spanner.Statement, handleRows func(*spanner.RowIterator) error, @@ -222,7 +222,7 @@ func (se *SpannerExecutor) executeQuery( } // queryStructs executes a query and maps the results to an input struct. -func (se *SpannerExecutor) queryStructs( +func (se *SpannerConnector) queryStructs( ctx context.Context, stmt spanner.Statement, newStruct func() interface{}, @@ -234,7 +234,7 @@ func (se *SpannerExecutor) queryStructs( } // queryDynamic executes a dynamically constructed query and returns the results as a slice of string slices. -func (se *SpannerExecutor) queryDynamic( +func (se *SpannerConnector) queryDynamic( ctx context.Context, stmt spanner.Statement, ) ([][]string, error) { @@ -250,7 +250,7 @@ func (se *SpannerExecutor) queryDynamic( // queryCache executes a query and maps the results to an input cache proto. func queryCache[T proto.Message]( ctx context.Context, - se *SpannerExecutor, + se *SpannerConnector, stmt spanner.Statement, newProto func() T, ) (map[string]map[string]T, error) { diff --git a/internal/server/spanner/golden/query_normalized_test.go b/internal/server/spanner/golden/query_normalized_test.go index 584e0cd91..1f72e5b2e 100644 --- a/internal/server/spanner/golden/query_normalized_test.go +++ b/internal/server/spanner/golden/query_normalized_test.go @@ -28,10 +28,7 @@ func TestNormalizedGetObservations(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range normalizedObservationsTestCases { goldenFile := c.golden + ".json" @@ -51,10 +48,7 @@ func TestNormalizedCheckVariableExistence(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range checkVariableExistenceTestCases { goldenFile := c.golden + ".json" @@ -85,10 +79,7 @@ func TestNormalizedGetObservationsContainedInPlace(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range getObservationsContainedInPlaceTestCases { goldenFile := c.golden + ".json" @@ -111,10 +102,7 @@ func TestNormalizedGetSdmxObservations(t *testing.T) { client := test.NewNormalizedSpannerClient(t) t.Parallel() - nc, err := spanner.NewNormalizedClient(client) - if err != nil { - t.Fatalf("NewNormalizedClient failed: %v", err) - } + nc := spanner.NewNormalizedClient(client) for _, c := range sdmxObservationsTestCases { goldenFile := c.golden + ".json" diff --git a/internal/server/spanner/query_normalized.go b/internal/server/spanner/normalized_client.go similarity index 78% rename from internal/server/spanner/query_normalized.go rename to internal/server/spanner/normalized_client.go index b2589576c..95ffca0ce 100644 --- a/internal/server/spanner/query_normalized.go +++ b/internal/server/spanner/normalized_client.go @@ -18,29 +18,48 @@ package spanner import ( "context" "fmt" - "log/slog" "strconv" pb "github.com/datacommonsorg/mixer/internal/proto" v2 "github.com/datacommonsorg/mixer/internal/server/v2" - "github.com/datacommonsorg/mixer/internal/util" - "google.golang.org/grpc/metadata" ) -func useNormalizedSchema(ctx context.Context) bool { - if md, ok := metadata.FromIncomingContext(ctx); ok { - headers := md.Get(util.XUseNormalizedSchema) - return len(headers) > 0 && headers[0] == "true" - } - return false +// normalizedSchemaClient encapsulates the Spanner client for the normalized schema. +// It implements specialized queries optimized for the normalized schema. +// +// DESIGN NOTE: This client embeds the SpannerClient interface (initialized with the +// standard client) primarily to fulfill the full interface and to provide full +// internal-facing functionality (e.g., if a normalized method needs to call a +// standard method like GetNodeProps internally). It is NOT intended to be used +// by the Selector for general request fallbacks, which are handled explicitly +// by the Selector itself. +type normalizedSchemaClient struct { + SpannerClient + conn *SpannerConnector } -// logNormalizedInvocation logs that the normalized schema was invoked for a method with custom arguments. -func logNormalizedInvocation(methodName string, args ...any) { - fullArgs := append([]any{"method", methodName}, args...) - slog.Info("Invoking normalized Spanner schema", fullArgs...) +// NewNormalizedClient creates a new normalizedSchemaClient. +func NewNormalizedClient(client SpannerClient) *normalizedSchemaClient { + var conn *SpannerConnector + if sc, ok := client.(*standardSpannerClient); ok { + conn = sc.exec + } else if sc, ok := client.(*selectorClient); ok { + if std, ok := sc.SpannerClient.(*standardSpannerClient); ok { + conn = std.exec + } + } + if conn == nil { + panic(fmt.Sprintf("NewNormalizedClient: unexpected client type %T", client)) + } + return &normalizedSchemaClient{ + SpannerClient: client, + conn: conn, + } } +// Force compiler that all methods required by the interface are implemented by clients +var _ SpannerClient = (*normalizedSchemaClient)(nil) + // GetObservations retrieves observations from Spanner given a list of variables and entities // using the normalized schema. func (nc *normalizedSchemaClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { @@ -48,15 +67,6 @@ func (nc *normalizedSchemaClient) GetObservations(ctx context.Context, variables return nil, fmt.Errorf("entity must be specified") } - if !useNormalizedSchema(ctx) { - return nc.SpannerClient.GetObservations(ctx, variables, entities) - } - - logNormalizedInvocation("GetObservations", - "num_variables", len(variables), - "num_entities", len(entities), - ) - rawObs, err := nc.fetchRawObservations(ctx, variables, entities) if err != nil { return nil, err @@ -74,7 +84,7 @@ func (nc *normalizedSchemaClient) fetchRawObservations(ctx context.Context, vari stmt := GetNormalizedObservationsQuery(variables, entities) var rawObs []*rawObservation - err := nc.exec.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { + err := nc.conn.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { rawObs = append(rawObs, row.(*rawObservation)) }) return rawObs, err @@ -118,20 +128,11 @@ func reconstructObservations(rawObs []*rawObservation) []*Observation { // CheckVariableExistence checks which variables exist for which entities using the normalized schema. func (nc *normalizedSchemaClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { - if !useNormalizedSchema(ctx) { - return nc.SpannerClient.CheckVariableExistence(ctx, variables, entities) - } - - logNormalizedInvocation("CheckVariableExistence", - "num_variables", len(variables), - "num_entities", len(entities), - ) - stmt, err := GetNormalizedStatVarsByEntityQuery(variables, entities) if err != nil { return nil, err } - return nc.exec.queryDynamic(ctx, *stmt) + return nc.conn.queryDynamic(ctx, *stmt) } // GetObservationsContainedInPlace retrieves observations for entities contained in a place @@ -141,16 +142,6 @@ func (nc *normalizedSchemaClient) GetObservationsContainedInPlace(ctx context.Co return nil, fmt.Errorf("containedInPlace must be specified") } - if !useNormalizedSchema(ctx) { - return nc.SpannerClient.GetObservationsContainedInPlace(ctx, variables, containedInPlace) - } - - logNormalizedInvocation("GetObservationsContainedInPlace", - "num_variables", len(variables), - "ancestor", containedInPlace.Ancestor, - "child_place_type", containedInPlace.ChildPlaceType, - ) - rawObs, err := nc.fetchRawObservationsContainedInPlace(ctx, variables, containedInPlace) if err != nil { return nil, err @@ -168,7 +159,7 @@ func (nc *normalizedSchemaClient) fetchRawObservationsContainedInPlace(ctx conte stmt := GetNormalizedObservationsContainedInPlaceQuery(variables, containedInPlace) var rawObs []*rawObservation - err := nc.exec.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { + err := nc.conn.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { rawObs = append(rawObs, row.(*rawObservation)) }) return rawObs, err @@ -186,14 +177,10 @@ var facetAttributes = map[string]bool{ // GetSdmxObservations retrieves observations from Spanner given a list of constraints // using the normalized schema and relational division. func (nc *normalizedSchemaClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { - logNormalizedInvocation("GetSdmxObservations", - "query", req, - ) - stmt := GetSdmxObservationsQuery(req) var rawObs []*rawObservation - err := nc.exec.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { + err := nc.conn.queryStructs(ctx, *stmt, func() interface{} { return &rawObservation{} }, func(row interface{}) { rawObs = append(rawObs, row.(*rawObservation)) }) if err != nil { diff --git a/internal/server/spanner/query.go b/internal/server/spanner/standard_client.go similarity index 95% rename from internal/server/spanner/query.go rename to internal/server/spanner/standard_client.go index 95f44b48d..121328f34 100644 --- a/internal/server/spanner/query.go +++ b/internal/server/spanner/standard_client.go @@ -34,6 +34,8 @@ import ( "github.com/datacommonsorg/mixer/internal/util" "golang.org/x/sync/errgroup" "google.golang.org/api/iterator" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) @@ -72,6 +74,35 @@ const ( maxConcurrentFilteredSVGGoroutines = 10 ) +// standardSpannerClient encapsulates the Spanner client that directly interacts with the Spanner database. +type standardSpannerClient struct { + exec *SpannerConnector +} + +// newStandardSpannerClient creates a new standardSpannerClient. +func newStandardSpannerClient(exec *SpannerConnector) *standardSpannerClient { + return &standardSpannerClient{exec: exec} +} + +func (sc *standardSpannerClient) Id() string { + return sc.exec.Id() +} + +// Start starts the background goroutine to periodically fetch the timestamp. +func (sc *standardSpannerClient) Start() { + sc.exec.Start() +} + +// Close closes the Spanner client and stops the background goroutine. +func (sc *standardSpannerClient) Close() { + sc.exec.Close() +} + +// GetSdmxObservations is not supported on the default client. +func (sc *standardSpannerClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDataQuery) (*pb.SdmxDataResult, error) { + return nil, status.Error(codes.Unimplemented, "SDMX queries are only supported on the normalized schema") +} + // GetNodeProps retrieves node properties from Spanner given a list of IDs and a direction and returns a map. func (sc *standardSpannerClient) GetNodeProps(ctx context.Context, ids []string, out bool) (map[string][]*Property, error) { props := map[string][]*Property{} @@ -423,12 +454,6 @@ func populateSpecialFields(event *pbv1.EventCollection_Event, edge *Edge) { } func populateGeoLocation(event *pbv1.EventCollection_Event, value string) { - // Note: The startLocation value in Spanner is usually a latLong/ DCID (e.g. latLong/577521_-958960). - // We parse it here for performance to avoid an extra database roundtrip. - // - // TODO(task): Revisit this optimization if we encounter valid startLocation values - // that are NOT latLong/ DCIDs but still need to be resolved to points, or if the - // assumption that dcids always contain coordinates is not true. if strings.HasPrefix(value, "latLong/") { parts := strings.Split(strings.TrimPrefix(value, "latLong/"), "_") if len(parts) == 2 { @@ -527,20 +552,7 @@ func (sc *standardSpannerClient) GetEventCollectionDcids(ctx context.Context, pl return res, nil } -type parsedEvent struct { - dcid string - magnitude float64 -} - // parseMagnitudeDcid parses the numeric magnitude value from a DCID string. -// -// Background: -// In the Spanner graph, quantity nodes have a `value` property that is identical to their DCID -// (e.g. `SquareKilometer91.57871`). Since we'd still receive a string with a prefix even after -// another jump, we can bypass the redundant join and parse the numeric value directly from the -// `object_id` of the edge in-memory. -// Ideally the value in Spanner should be stored as just the value and not this awkward string. -// If that happens, we can remove this function and just use the value directly. func parseMagnitudeDcid(magnitudeDcid, unit string) float64 { if magnitudeDcid == "" || unit == "" { return 0.0 @@ -554,6 +566,11 @@ func parseMagnitudeDcid(magnitudeDcid, unit string) float64 { return v } +type parsedEvent struct { + dcid string + magnitude float64 +} + // parseAndSortEvents parses magnitude DCIDs, sorts events by magnitude then DCID alphabetical, and truncates to top 100. func parseAndSortEvents(rows []EventIdWithMagnitudeDcid, eventType string) []string { cfg, hasCfg := EventConfigs[eventType] @@ -706,7 +723,7 @@ func (sc *standardSpannerClient) GetStatVarGroupNode(ctx context.Context, nodes func (sc *standardSpannerClient) GetFilteredStatVarGroupNode(ctx context.Context, nodes []string, constrainedPlaces []string, constrainedImport string, numEntitiesExistence int, includeDefinitions bool) (map[string]*FilteredStatVarGroupNode, error) { response := map[string]*FilteredStatVarGroupNode{} errGroup, errCtx := errgroup.WithContext(ctx) - errGroup.SetLimit(maxConcurrentFilteredSVGGoroutines) // Limit the number of concurrent goroutines to avoid overwhelming Spanner with too many requests. + errGroup.SetLimit(maxConcurrentFilteredSVGGoroutines) type nodeResult struct { node string @@ -851,6 +868,3 @@ func (sc *standardSpannerClient) GetFilteredTopic(ctx context.Context, nodes []s return counts, nil } - - - diff --git a/internal/server/spanner/timestamp_test.go b/internal/server/spanner/timestamp_test.go index 5d6cd7e72..0c84e8964 100644 --- a/internal/server/spanner/timestamp_test.go +++ b/internal/server/spanner/timestamp_test.go @@ -53,7 +53,7 @@ func TestTimestampUpdated(t *testing.T) { startTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC) updateDone := make(chan bool, 1) - sc := &SpannerExecutor{ + sc := &SpannerConnector{ ticker: mockTicker, stopCh: make(chan struct{}), } @@ -96,7 +96,7 @@ func TestTimestampUpdateFailure(t *testing.T) { startTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC) updateDone := make(chan bool, 1) - sc := &SpannerExecutor{ + sc := &SpannerConnector{ ticker: mockTicker, stopCh: make(chan struct{}), } From 09a025bf030ee45b3503e930858dc05bcdca6482 Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Fri, 15 May 2026 12:02:39 -0700 Subject: [PATCH 6/8] re-org --- internal/server/spanner/client.go | 114 +++++++-------------- internal/server/spanner/connector.go | 24 ++++- internal/server/spanner/standard_client.go | 14 +++ 3 files changed, 72 insertions(+), 80 deletions(-) diff --git a/internal/server/spanner/client.go b/internal/server/spanner/client.go index b8620ee31..6d12fa260 100644 --- a/internal/server/spanner/client.go +++ b/internal/server/spanner/client.go @@ -20,7 +20,6 @@ import ( "fmt" "log/slog" - "cloud.google.com/go/spanner" pb "github.com/datacommonsorg/mixer/internal/proto" pbv1 "github.com/datacommonsorg/mixer/internal/proto/v1" v2 "github.com/datacommonsorg/mixer/internal/server/v2" @@ -56,16 +55,6 @@ type SpannerClient interface { Close() } - - -func useNormalizedSchema(ctx context.Context) bool { - if md, ok := metadata.FromIncomingContext(ctx); ok { - headers := md.Get(util.XUseNormalizedSchema) - return len(headers) > 0 && headers[0] == "true" - } - return false -} - // selectorClient dispatches calls to either default or normalized client based on request headers. // It serves as the main entry point for the Spanner client, centralizing routing concerns. // @@ -76,16 +65,33 @@ func useNormalizedSchema(ctx context.Context) bool { // internal fallback for general request routing, ensuring that the standard path remains // the explicit default. type selectorClient struct { - SpannerClient // Embeds the default client + SpannerClient // Embeds the standard client as the default client normalized *normalizedSchemaClient } -// logNormalizedInvocation logs that the normalized schema was invoked for a method with custom arguments. -func logNormalizedInvocation(methodName string, args ...any) { - fullArgs := append([]any{"method", methodName}, args...) - slog.Info("Invoking normalized Spanner schema", fullArgs...) +// NewSpannerClient creates a new SpannerClient from the config yaml string and an optional database override. +// It returns a wrapper client that handles request-time schema dispatching. +func NewSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride string) (SpannerClient, error) { + cfg, err := createSpannerConfig(spannerConfigYaml, databaseOverride) + if err != nil { + return nil, err + } + exec, err := NewSpannerConnector(ctx, cfg) + if err != nil { + return nil, err + } + + defaultClient := newStandardSpannerClient(exec) + normalizedClient := NewNormalizedClient(defaultClient) + + return &selectorClient{ + SpannerClient: defaultClient, + normalized: normalizedClient, + }, nil } + + // GetObservations overrides the embedded client's GetObservations to dispatch based on schema selection. func (s *selectorClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { if useNormalizedSchema(ctx) { @@ -132,67 +138,6 @@ func (s *selectorClient) GetSdmxObservations(ctx context.Context, req *pb.SdmxDa return s.normalized.GetSdmxObservations(ctx, req) } -// Force compiler that all methods required by the interface are implemented by clients -var _ SpannerClient = (*standardSpannerClient)(nil) -var _ SpannerClient = (*selectorClient)(nil) - -// NewRawSpannerClient creates a new SpannerClient without the schema selector. -// This is intended for testing and internal use where a direct client is needed. -func NewRawSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride string) (SpannerClient, error) { - cfg, err := createSpannerConfig(spannerConfigYaml, databaseOverride) - if err != nil { - return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) - } - client, err := createSpannerClient(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) - } - exec, err := NewSpannerConnector(client) - if err != nil { - return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) - } - return newStandardSpannerClient(exec), nil -} - -// NewSpannerClient creates a new SpannerClient from the config yaml string and an optional database override. -// It returns a wrapper client that handles request-time schema dispatching. -func NewSpannerClient(ctx context.Context, spannerConfigYaml, databaseOverride string) (SpannerClient, error) { - cfg, err := createSpannerConfig(spannerConfigYaml, databaseOverride) - if err != nil { - return nil, err - } - client, err := createSpannerClient(ctx, cfg) - if err != nil { - return nil, err - } - exec, err := NewSpannerConnector(client) - if err != nil { - return nil, err - } - - defaultClient := newStandardSpannerClient(exec) - normalizedClient := NewNormalizedClient(defaultClient) - - return &selectorClient{ - SpannerClient: defaultClient, - normalized: normalizedClient, - }, nil -} - -// createSpannerClient creates the database name string and initializes the Spanner client. -func createSpannerClient(ctx context.Context, cfg *SpannerConfig) (*spanner.Client, error) { - // Construct the database name string - databaseName := fmt.Sprintf("projects/%s/instances/%s/databases/%s", cfg.Project, cfg.Instance, cfg.Database) - - // Create the Spanner client - client, err := spanner.NewClient(ctx, databaseName) - if err != nil { - return nil, fmt.Errorf("failed to create Spanner client: %w", err) - } - - return client, nil -} - // createSpannerConfig creates the config from the specific yaml string and an optional database override. func createSpannerConfig(spannerConfigYaml, databaseOverride string) (*SpannerConfig, error) { var cfg SpannerConfig @@ -211,4 +156,19 @@ func createSpannerConfig(spannerConfigYaml, databaseOverride string) (*SpannerCo return &cfg, nil } +// useNormalizedSchema checks whether to use the normalized Spanner schema based on request header. +func useNormalizedSchema(ctx context.Context) bool { + if md, ok := metadata.FromIncomingContext(ctx); ok { + headers := md.Get(util.XUseNormalizedSchema) + return len(headers) > 0 && headers[0] == "true" + } + return false +} + +// logNormalizedInvocation logs that the normalized schema was invoked for a method with custom arguments. +func logNormalizedInvocation(methodName string, args ...any) { + fullArgs := append([]any{"method", methodName}, args...) + slog.Info("Invoking normalized Spanner schema", fullArgs...) +} + diff --git a/internal/server/spanner/connector.go b/internal/server/spanner/connector.go index 56274e159..f7b68eec4 100644 --- a/internal/server/spanner/connector.go +++ b/internal/server/spanner/connector.go @@ -28,11 +28,11 @@ import ( "cloud.google.com/go/spanner" "github.com/datacommonsorg/mixer/internal/metrics" "github.com/datacommonsorg/mixer/internal/util" + "google.golang.org/api/iterator" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "google.golang.org/api/iterator" ) // SpannerConnector handles the low-level details of connecting to Spanner and executing queries. @@ -48,7 +48,11 @@ type SpannerConnector struct { } // NewSpannerConnector creates a new SpannerConnector. -func NewSpannerConnector(client *spanner.Client) (*SpannerConnector, error) { +func NewSpannerConnector(ctx context.Context, cfg *SpannerConfig) (*SpannerConnector, error) { + client, err := newDBClient(ctx, cfg) + if err != nil { + return nil, err + } se := &SpannerConnector{ client: client, } @@ -57,7 +61,7 @@ func NewSpannerConnector(client *spanner.Client) (*SpannerConnector, error) { se.ticker = NewTimestampTicker() se.stopCh = make(chan struct{}) se.updateTimestamp = se.fetchAndUpdateTimestamp - if err := se.updateTimestamp(context.Background()); err != nil { + if err := se.updateTimestamp(ctx); err != nil { slog.Error("Error initializing Spanner staleness timestamp", "error", err.Error()) return nil, err } @@ -369,3 +373,17 @@ func getSchemaName(ctx context.Context) string { } return "Legacy" } + +// newDBClient creates the database name string and initializes the Spanner client. +func newDBClient(ctx context.Context, cfg *SpannerConfig) (*spanner.Client, error) { + // Construct the database name string + databaseName := fmt.Sprintf("projects/%s/instances/%s/databases/%s", cfg.Project, cfg.Instance, cfg.Database) + + // Create the Spanner client + client, err := spanner.NewClient(ctx, databaseName) + if err != nil { + return nil, fmt.Errorf("failed to create Spanner client: %w", err) + } + + return client, nil +} diff --git a/internal/server/spanner/standard_client.go b/internal/server/spanner/standard_client.go index 121328f34..33666bda9 100644 --- a/internal/server/spanner/standard_client.go +++ b/internal/server/spanner/standard_client.go @@ -868,3 +868,17 @@ func (sc *standardSpannerClient) GetFilteredTopic(ctx context.Context, nodes []s return counts, nil } + +// NewStandardClient creates a new SpannerClient without the schema selector. +// This is intended for testing and internal use where a direct client is needed. +func NewStandardClient(ctx context.Context, spannerConfigYaml, databaseOverride string) (SpannerClient, error) { + cfg, err := createSpannerConfig(spannerConfigYaml, databaseOverride) + if err != nil { + return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) + } + exec, err := NewSpannerConnector(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create standardSpannerClient: %w", err) + } + return newStandardSpannerClient(exec), nil +} From 1fef570c6689cd795b8523b05ee5d485a616b133 Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Fri, 15 May 2026 21:20:44 -0700 Subject: [PATCH 7/8] refine tests --- internal/server/spanner/client.go | 2 +- internal/server/spanner/client_test.go | 151 +++++++++++++++ internal/server/spanner/datasource_test.go | 180 ++++++++++++++++++ .../golden/datasource_normalized_test.go | 131 ------------- .../server/spanner/golden/datasource_test.go | 163 ---------------- 5 files changed, 332 insertions(+), 295 deletions(-) create mode 100644 internal/server/spanner/client_test.go create mode 100644 internal/server/spanner/datasource_test.go delete mode 100644 internal/server/spanner/golden/datasource_normalized_test.go diff --git a/internal/server/spanner/client.go b/internal/server/spanner/client.go index 6d12fa260..3fbe3a2f7 100644 --- a/internal/server/spanner/client.go +++ b/internal/server/spanner/client.go @@ -66,7 +66,7 @@ type SpannerClient interface { // the explicit default. type selectorClient struct { SpannerClient // Embeds the standard client as the default client - normalized *normalizedSchemaClient + normalized SpannerClient } // NewSpannerClient creates a new SpannerClient from the config yaml string and an optional database override. diff --git a/internal/server/spanner/client_test.go b/internal/server/spanner/client_test.go new file mode 100644 index 000000000..1820bb477 --- /dev/null +++ b/internal/server/spanner/client_test.go @@ -0,0 +1,151 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "bytes" + "context" + "log/slog" + "slices" + "strings" + "testing" + + v2 "github.com/datacommonsorg/mixer/internal/server/v2" + "github.com/datacommonsorg/mixer/internal/util" + "google.golang.org/grpc/metadata" +) + +// mockSpannerClient embeds SpannerClient to reduce boilerplate. +type mockSpannerClient struct { + SpannerClient + resolveByIDRes map[string][]string + getNodeEdgesRes map[string][]*Edge + checkVariableExistenceRes [][]string + filterNodesByTypeRes map[string][]string + getObservationsRes []*Observation + getObservationsContainedInPlaceRes []*Observation +} + +func (m *mockSpannerClient) GetNodeProps(ctx context.Context, ids []string, out bool) (map[string][]*Property, error) { + return nil, nil +} +func (m *mockSpannerClient) GetNodeEdgesByID(ctx context.Context, ids []string, arc *v2.Arc, pageSize, offset int) (map[string][]*Edge, error) { + return m.getNodeEdgesRes, nil +} +func (m *mockSpannerClient) GetObservations(ctx context.Context, variables []string, entities []string) ([]*Observation, error) { + return m.getObservationsRes, nil +} +func (m *mockSpannerClient) CheckVariableExistence(ctx context.Context, variables []string, entities []string) ([][]string, error) { + return m.checkVariableExistenceRes, nil +} +func (m *mockSpannerClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) { + return m.getObservationsContainedInPlaceRes, nil +} +func (m *mockSpannerClient) ResolveByID(ctx context.Context, nodes []string, in, out string) (map[string][]string, error) { + return m.resolveByIDRes, nil +} +func (m *mockSpannerClient) FilterNodesByTypes(ctx context.Context, nodes []string, typeFilters []string) (map[string][]string, error) { + res := map[string][]string{} + for _, typeFilter := range typeFilters { + allowedNodes := m.filterNodesByTypeRes[typeFilter] + for _, node := range nodes { + if slices.Contains(allowedNodes, node) { + res[node] = append(res[node], typeFilter) + } + } + } + return res, nil +} +func (m *mockSpannerClient) Id() string { return "mock" } + +func TestSelectorClient_GetObservations(t *testing.T) { + mockDefault := &mockSpannerClient{ + getObservationsRes: []*Observation{{VariableMeasured: "var_default"}}, + } + mockNormalized := &mockSpannerClient{ + getObservationsRes: []*Observation{{VariableMeasured: "var_normalized"}}, + } + + client := &selectorClient{ + SpannerClient: mockDefault, + normalized: mockNormalized, + } + + // Test Case 1: Default path (no header) + ctx := context.Background() + got, err := client.GetObservations(ctx, []string{"var1"}, []string{"entity1"}) + if err != nil { + t.Fatalf("GetObservations failed: %v", err) + } + if len(got) != 1 || got[0].VariableMeasured != "var_default" { + t.Errorf("Expected var_default, got %v", got) + } + + // Test Case 2: Normalized path (header set to true) + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(util.XUseNormalizedSchema, "true")) + + // Capture slog output + var buf bytes.Buffer + handler := slog.NewTextHandler(&buf, nil) + logger := slog.New(handler) + originalLogger := slog.Default() + slog.SetDefault(logger) + defer slog.SetDefault(originalLogger) + + got, err = client.GetObservations(ctx, []string{"var1"}, []string{"entity1"}) + if err != nil { + t.Fatalf("GetObservations failed: %v", err) + } + if len(got) != 1 || got[0].VariableMeasured != "var_normalized" { + t.Errorf("Expected var_normalized, got %v", got) + } + if !strings.Contains(buf.String(), "Invoking normalized Spanner schema") { + t.Errorf("Expected log not found. Logs: %s", buf.String()) + } +} + +func TestSelectorClient_CheckVariableExistence(t *testing.T) { + mockDefault := &mockSpannerClient{ + checkVariableExistenceRes: [][]string{{"var_default"}}, + } + mockNormalized := &mockSpannerClient{ + checkVariableExistenceRes: [][]string{{"var_normalized"}}, + } + + client := &selectorClient{ + SpannerClient: mockDefault, + normalized: mockNormalized, + } + + // Test Case 1: Default path + ctx := context.Background() + got, err := client.CheckVariableExistence(ctx, []string{"var1"}, []string{"entity1"}) + if err != nil { + t.Fatalf("CheckVariableExistence failed: %v", err) + } + if len(got) != 1 || got[0][0] != "var_default" { + t.Errorf("Expected var_default, got %v", got) + } + + // Test Case 2: Normalized path + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(util.XUseNormalizedSchema, "true")) + got, err = client.CheckVariableExistence(ctx, []string{"var1"}, []string{"entity1"}) + if err != nil { + t.Fatalf("CheckVariableExistence failed: %v", err) + } + if len(got) != 1 || got[0][0] != "var_normalized" { + t.Errorf("Expected var_normalized, got %v", got) + } +} diff --git a/internal/server/spanner/datasource_test.go b/internal/server/spanner/datasource_test.go new file mode 100644 index 000000000..6667de8f8 --- /dev/null +++ b/internal/server/spanner/datasource_test.go @@ -0,0 +1,180 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "testing" + + pbv2 "github.com/datacommonsorg/mixer/internal/proto/v2" + "github.com/datacommonsorg/mixer/internal/server/dispatcher" +) + + + +// Tests moved from golden/datasource_test.go + +func TestSpannerObservation_ExpressionExpansion(t *testing.T) { + ctx := context.Background() + + client := &mockSpannerClient{ + getNodeEdgesRes: map[string][]*Edge{ + "geoId/06": { + {Value: "geoId/06002", Predicate: "linkedContainedInPlace"}, + }, + }, + getObservationsRes: []*Observation{ + { + VariableMeasured: "Count_Person", + ObservationAbout: "geoId/06001", + Observations: []*DateValue{ + {Date: "2020", Value: "12345"}, + }, + }, + { + VariableMeasured: "Count_Person", + ObservationAbout: "geoId/06002", + Observations: []*DateValue{ + {Date: "2020", Value: "67890"}, + }, + }, + }, + } + + ds := NewSpannerDataSource(client, nil, nil, false) + + req := &pbv2.ObservationRequest{ + Variable: &pbv2.DcidOrExpression{Dcids: []string{"Count_Person"}}, + Entity: &pbv2.DcidOrExpression{Expression: "geoId/06<-containedInPlace+{typeOf:County}"}, + Select: []string{"variable", "entity", "value"}, + } + + remoteDCIDs := []string{"geoId/06001"} + ctxWithRemote := context.WithValue(ctx, dispatcher.RelationExpressionExpandedEntities, remoteDCIDs) + + resp, err := ds.Observation(ctxWithRemote, req) + if err != nil { + t.Fatalf("Observation failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + return + } + + byVariable := resp.ByVariable + if byVariable == nil { + t.Fatal("Expected ByVariable to be populated") + } + countPerson, ok := byVariable["Count_Person"] + if !ok { + t.Fatal("Expected Count_Person in response") + } + byEntity := countPerson.ByEntity + if byEntity == nil { + t.Fatal("Expected ByEntity to be populated") + } + + if _, ok := byEntity["geoId/06001"]; !ok { + t.Errorf("Expected data for geoId/06001 (remote place)") + } + if _, ok := byEntity["geoId/06002"]; !ok { + t.Errorf("Expected data for geoId/06002 (local place)") + } +} + +func TestSpannerObservation_ExpressionExpansion_Fallback(t *testing.T) { + ctx := context.Background() + + client := &mockSpannerClient{ + getObservationsContainedInPlaceRes: []*Observation{ + { + VariableMeasured: "Count_Person", + ObservationAbout: "geoId/06002", + Observations: []*DateValue{ + {Date: "2020", Value: "67890"}, + }, + }, + }, + } + + ds := NewSpannerDataSource(client, nil, nil, false) + + req := &pbv2.ObservationRequest{ + Variable: &pbv2.DcidOrExpression{Dcids: []string{"Count_Person"}}, + Entity: &pbv2.DcidOrExpression{Expression: "geoId/06<-containedInPlace+{typeOf:County}"}, + Select: []string{"variable", "entity", "value"}, + } + + resp, err := ds.Observation(ctx, req) + if err != nil { + t.Fatalf("Observation failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + return + } + + byVariable := resp.ByVariable + countPerson := byVariable["Count_Person"] + byEntity := countPerson.ByEntity + + if _, ok := byEntity["geoId/06002"]; !ok { + t.Errorf("Expected data for geoId/06002 (local place)") + } +} + +func TestSpannerObservation_NoExpression(t *testing.T) { + ctx := context.Background() + + client := &mockSpannerClient{ + getObservationsRes: []*Observation{ + { + VariableMeasured: "Count_Person", + ObservationAbout: "geoId/06", + Observations: []*DateValue{ + {Date: "2020", Value: "12345"}, + }, + }, + }, + } + + ds := NewSpannerDataSource(client, nil, nil, false) + + req := &pbv2.ObservationRequest{ + Variable: &pbv2.DcidOrExpression{Dcids: []string{"Count_Person"}}, + Entity: &pbv2.DcidOrExpression{Dcids: []string{"geoId/06"}}, + Select: []string{"variable", "entity", "value"}, + } + + resp, err := ds.Observation(ctx, req) + if err != nil { + t.Fatalf("Observation failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + return + } + + byVariable := resp.ByVariable + countPerson := byVariable["Count_Person"] + byEntity := countPerson.ByEntity + + if _, ok := byEntity["geoId/06"]; !ok { + t.Errorf("Expected data for geoId/06") + } +} diff --git a/internal/server/spanner/golden/datasource_normalized_test.go b/internal/server/spanner/golden/datasource_normalized_test.go deleted file mode 100644 index 519368eb1..000000000 --- a/internal/server/spanner/golden/datasource_normalized_test.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package golden - -import ( - "bytes" - "context" - "log/slog" - "path" - "runtime" - "strings" - "testing" - - pbv2 "github.com/datacommonsorg/mixer/internal/proto/v2" - "github.com/datacommonsorg/mixer/internal/server/spanner" - "github.com/datacommonsorg/mixer/internal/util" - "github.com/datacommonsorg/mixer/test" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/metadata" - "google.golang.org/protobuf/testing/protocmp" -) - -// Note: Tests in this file are executed sequentially within a single top-level test. -// They modify the global slog.Default() to capture and assert on logs, which makes -// them unsafe for parallel execution. Please avoid adding too many test cases here -// to prevent slowing down the test suite, as each case runs sequentially and involves -// expensive database calls. -func TestObservation_SchemaSelector(t *testing.T) { - client := test.NewSchemaSelectorSpannerClient(t) - ds := spanner.NewSpannerDataSource(client, nil, nil, false) - - _, filename, _, _ := runtime.Caller(0) - goldenDir := path.Join(path.Dir(filename), "query") - - testCases := []struct { - desc string - req *pbv2.ObservationRequest - useNormalizedHeader bool - goldenFile string - }{ - { - desc: "Default path - basic query", - req: &pbv2.ObservationRequest{ - Variable: &pbv2.DcidOrExpression{ - Dcids: []string{"AirPollutant_Cancer_Risk"}, - }, - Entity: &pbv2.DcidOrExpression{ - Dcids: []string{"geoId/01001"}, - }, - Select: []string{"variable", "entity", "date", "value"}, - }, - useNormalizedHeader: false, - goldenFile: "default_obs_basic.json", - }, - { - desc: "Normalized path - basic query", - req: &pbv2.ObservationRequest{ - Variable: &pbv2.DcidOrExpression{ - Dcids: []string{"AirPollutant_Cancer_Risk"}, - }, - Entity: &pbv2.DcidOrExpression{ - Dcids: []string{"geoId/01001"}, - }, - Select: []string{"variable", "entity", "date", "value"}, - }, - useNormalizedHeader: true, - goldenFile: "normalized_obs_basic.json", - }, - } - - for _, c := range testCases { - t.Run(c.desc, func(t *testing.T) { - // Capture slog output - var buf bytes.Buffer - handler := slog.NewTextHandler(&buf, nil) - logger := slog.New(handler) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) - - ctx := context.Background() - if c.useNormalizedHeader { - ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(util.XUseNormalizedSchema, "true")) - } - - got, err := ds.Observation(ctx, c.req) - if err != nil { - t.Fatalf("Observation error: %v", err) - } - - // Log assertions - logStr := buf.String() - hasLog := strings.Contains(logStr, "Invoking normalized Spanner schema") - if c.useNormalizedHeader && !hasLog { - t.Errorf("Expected log message 'Invoking normalized Spanner schema' not found in logs: %s", logStr) - } - if !c.useNormalizedHeader && hasLog { - t.Errorf("Unexpected log message 'Invoking normalized Spanner schema' found in logs: %s", logStr) - } - - if test.GenerateGolden { - test.UpdateProtoGolden(got, goldenDir, c.goldenFile) - return - } - - var want pbv2.ObservationResponse - if err = test.ReadJSON(goldenDir, c.goldenFile, &want); err != nil { - t.Fatalf("ReadJSON error (%v): %v", c.goldenFile, err) - } - - cmpOpts := cmp.Options{ - protocmp.Transform(), - } - if diff := cmp.Diff(got, &want, cmpOpts); diff != "" { - t.Errorf("%s: %v payload mismatch:\n%v", c.desc, c.goldenFile, diff) - } - }) - } -} diff --git a/internal/server/spanner/golden/datasource_test.go b/internal/server/spanner/golden/datasource_test.go index 0452198b8..1b4cc47b2 100644 --- a/internal/server/spanner/golden/datasource_test.go +++ b/internal/server/spanner/golden/datasource_test.go @@ -27,7 +27,6 @@ import ( pbv1 "github.com/datacommonsorg/mixer/internal/proto/v1" pbv2 "github.com/datacommonsorg/mixer/internal/proto/v2" "github.com/datacommonsorg/mixer/internal/server/datasources" - "github.com/datacommonsorg/mixer/internal/server/dispatcher" "github.com/datacommonsorg/mixer/internal/server/spanner" v2 "github.com/datacommonsorg/mixer/internal/server/v2" "github.com/datacommonsorg/mixer/internal/store/files" @@ -589,166 +588,4 @@ func TestBulkVariableGroupInfo_Filtering(t *testing.T) { } } -// TODO: Move unit tests to a separate test file since this file is meant for golden tests. -func TestSpannerObservation_ExpressionExpansion(t *testing.T) { - ctx := context.Background() - - // Mock Spanner client - client := &mockSpannerClient{ - // Mock GetNodeEdgesByID to return local child places - getNodeEdgesRes: map[string][]*spanner.Edge{ - "geoId/06": { - {Value: "geoId/06002", Predicate: "linkedContainedInPlace"}, - }, - }, - // Mock GetObservations to return observations for merged list - getObservationsRes: []*spanner.Observation{ - { - VariableMeasured: "Count_Person", - ObservationAbout: "geoId/06001", // Remote place - Observations: []*spanner.DateValue{ - {Date: "2020", Value: "12345"}, - }, - }, - { - VariableMeasured: "Count_Person", - ObservationAbout: "geoId/06002", // Local place - Observations: []*spanner.DateValue{ - {Date: "2020", Value: "67890"}, - }, - }, - }, - } - - ds := spanner.NewSpannerDataSource(client, nil, nil, false) - - // Test Case 1: Expression with Remote Data in Context - req := &pbv2.ObservationRequest{ - Variable: &pbv2.DcidOrExpression{Dcids: []string{"Count_Person"}}, - Entity: &pbv2.DcidOrExpression{Expression: "geoId/06<-containedInPlace+{typeOf:County}"}, - Select: []string{"variable", "entity", "value"}, - } - - // Add remote DCIDs to context - remoteDCIDs := []string{"geoId/06001"} - ctxWithRemote := context.WithValue(ctx, dispatcher.RelationExpressionExpandedEntities, remoteDCIDs) - - resp, err := ds.Observation(ctxWithRemote, req) - if err != nil { - t.Fatalf("Observation failed: %v", err) - } - if resp == nil { - t.Fatal("Expected non-nil response") - return - } - - // Verify that we have data for both geoId/06001 and geoId/06002 - byVariable := resp.ByVariable - if byVariable == nil { - t.Fatal("Expected ByVariable to be populated") - } - countPerson, ok := byVariable["Count_Person"] - if !ok { - t.Fatal("Expected Count_Person in response") - } - byEntity := countPerson.ByEntity - if byEntity == nil { - t.Fatal("Expected ByEntity to be populated") - } - - if _, ok := byEntity["geoId/06001"]; !ok { - t.Errorf("Expected data for geoId/06001 (remote place)") - } - if _, ok := byEntity["geoId/06002"]; !ok { - t.Errorf("Expected data for geoId/06002 (local place)") - } -} - -func TestSpannerObservation_ExpressionExpansion_Fallback(t *testing.T) { - ctx := context.Background() - - // Mock Spanner client - client := &mockSpannerClient{ - // Mock GetObservationsContainedInPlace to return observations - getObservationsContainedInPlaceRes: []*spanner.Observation{ - { - VariableMeasured: "Count_Person", - ObservationAbout: "geoId/06002", // Local place - Observations: []*spanner.DateValue{ - {Date: "2020", Value: "67890"}, - }, - }, - }, - } - - ds := spanner.NewSpannerDataSource(client, nil, nil, false) - - req := &pbv2.ObservationRequest{ - Variable: &pbv2.DcidOrExpression{Dcids: []string{"Count_Person"}}, - Entity: &pbv2.DcidOrExpression{Expression: "geoId/06<-containedInPlace+{typeOf:County}"}, - Select: []string{"variable", "entity", "value"}, - } - - resp, err := ds.Observation(ctx, req) - if err != nil { - t.Fatalf("Observation failed: %v", err) - } - - if resp == nil { - t.Fatal("Expected non-nil response") - return - } - - byVariable := resp.ByVariable - countPerson := byVariable["Count_Person"] - byEntity := countPerson.ByEntity - - if _, ok := byEntity["geoId/06002"]; !ok { - t.Errorf("Expected data for geoId/06002 (local place)") - } -} - -func TestSpannerObservation_NoExpression(t *testing.T) { - ctx := context.Background() - - // Mock Spanner client - client := &mockSpannerClient{ - // Mock GetObservations to return observations - getObservationsRes: []*spanner.Observation{ - { - VariableMeasured: "Count_Person", - ObservationAbout: "geoId/06", - Observations: []*spanner.DateValue{ - {Date: "2020", Value: "12345"}, - }, - }, - }, - } - - ds := spanner.NewSpannerDataSource(client, nil, nil, false) - - req := &pbv2.ObservationRequest{ - Variable: &pbv2.DcidOrExpression{Dcids: []string{"Count_Person"}}, - Entity: &pbv2.DcidOrExpression{Dcids: []string{"geoId/06"}}, - Select: []string{"variable", "entity", "value"}, - } - - resp, err := ds.Observation(ctx, req) - if err != nil { - t.Fatalf("Observation failed: %v", err) - } - - if resp == nil { - t.Fatal("Expected non-nil response") - return - } - - byVariable := resp.ByVariable - countPerson := byVariable["Count_Person"] - byEntity := countPerson.ByEntity - - if _, ok := byEntity["geoId/06"]; !ok { - t.Errorf("Expected data for geoId/06") - } -} From e18375947d7898970e08ea0ff4f0206e757095ff Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Fri, 15 May 2026 21:41:54 -0700 Subject: [PATCH 8/8] nit --- test/setup.go | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/test/setup.go b/test/setup.go index f62bf97fd..ebc927110 100644 --- a/test/setup.go +++ b/test/setup.go @@ -404,7 +404,17 @@ func NewSpannerClient() spanner.SpannerClient { } _, filename, _, _ := runtime.Caller(0) spannerGraphInfoYamlPath := path.Join(path.Dir(filename), "../deploy/storage/spanner_graph_info.yaml") - return newSpannerClient(context.Background(), spannerGraphInfoYamlPath) + + spannerGraphInfoYaml, err := os.ReadFile(spannerGraphInfoYamlPath) + if err != nil { + log.Fatalf("Failed to read spanner yaml: %v", err) + } + spannerClient, err := spanner.NewSpannerClient(context.Background(), string(spannerGraphInfoYaml), "") + if err != nil { + log.Fatalf("Failed to create SpannerClient: %v", err) + } + spannerClient.Start() + return spannerClient } // SkipIfNormalizedSchemaDisabled skips the test if ENABLE_SPANNER_NORMALIZED_SCHEMA is not set. @@ -425,22 +435,14 @@ func NewNormalizedSpannerClient(t *testing.T) spanner.SpannerClient { return client } -func newSpannerClient(ctx context.Context, spannerGraphInfoYamlPath string) spanner.SpannerClient { - spannerGraphInfoYaml, err := os.ReadFile(spannerGraphInfoYamlPath) - if err != nil { - log.Fatalf("Failed to read spanner yaml: %v", err) - } - // Use NewSpannerClient to get the full setup with selector. - spannerClient, err := spanner.NewSpannerClient(ctx, string(spannerGraphInfoYaml), "") - if err != nil { - log.Fatalf("Failed to create SpannerClient: %v", err) - } - spannerClient.Start() - return spannerClient -} + // NewSchemaSelectorSpannerClient creates a new test schema selector spanner client. func NewSchemaSelectorSpannerClient(t *testing.T) spanner.SpannerClient { - // NewNormalizedSpannerClient already returns the selector client because newSpannerClient uses NewSpannerClient! - return NewNormalizedSpannerClient(t) + SkipIfNormalizedSchemaDisabled(t) + client := NewSpannerClient() + if client == nil { + t.Skip("Skipping selector tests (Spanner graph not enabled)") + } + return client }