From eec067ad511a11f63c03a2c90bb89aedab2037db Mon Sep 17 00:00:00 2001 From: omirandadev Date: Mon, 16 Feb 2026 15:39:41 -0600 Subject: [PATCH] feat: Change default keyBatchSize and optimize prepared statement caching --- .../feast/onlinestore/cassandraonlinestore.go | 132 +++++++++++++----- 1 file changed, 100 insertions(+), 32 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index a8c86feafc2..2e2aa9d3db0 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -214,8 +214,8 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, readBatchSize = legacyBatchSize log.Warn().Msg("key_batch_size is deprecated, please use read_batch_size instead") } else { - readBatchSize = 100.0 - log.Warn().Msg("read_batch_size not specified, defaulting to batches of size 100") + readBatchSize = 10.0 + log.Warn().Msg("read_batch_size not specified, defaulting to batches of size 10") } } cassandraConfig.readBatchSize = int(readBatchSize.(float64)) @@ -342,25 +342,23 @@ func (c *CassandraOnlineStore) getFqTableName(keySpace string, project string, f return fmt.Sprintf(`"%s"."%s"`, keySpace, dbTableName), nil } -func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, featureNames []string) string { - // this prevents fetching unnecessary features - quotedFeatureNames := make([]string, len(featureNames)) - for i, featureName := range featureNames { - quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName) +func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, numFeatures int) string { + featurePlaceholders := make([]string, numFeatures) + for i := 0; i < numFeatures; i++ { + featurePlaceholders[i] = "?" } return fmt.Sprintf( `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? AND "feature_name" IN (%s)`, tableName, - strings.Join(quotedFeatureNames, ","), + strings.Join(featurePlaceholders, ","), ) } -func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, featureNames []string, nkeys int) string { - // this prevents fetching unnecessary features - quotedFeatureNames := make([]string, len(featureNames)) - for i, featureName := range featureNames { - quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName) +func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, numFeatures int, nkeys int) string { + featurePlaceholders := make([]string, numFeatures) + for i := 0; i < numFeatures; i++ { + featurePlaceholders[i] = "?" } keyPlaceholders := make([]string, nkeys) @@ -371,7 +369,7 @@ func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, feature `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`, tableName, strings.Join(keyPlaceholders, ","), - strings.Join(quotedFeatureNames, ","), + strings.Join(featurePlaceholders, ","), ) } @@ -425,6 +423,15 @@ type BatchJob struct { CQLStatement string } +func buildQueryParams(entityKeys []any, featureNames []string) []any { + params := make([]any, 0, len(entityKeys)+len(featureNames)) + params = append(params, entityKeys...) + for _, fn := range featureNames { + params = append(params, fn) + } + return params +} + func (c *CassandraOnlineStore) OnlineReadV2(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) if err != nil { @@ -443,38 +450,97 @@ func (c *CassandraOnlineStore) OnlineReadV2(ctx context.Context, entityKeys []*t return nil, err } - var cqlForBatch string - cqlForBatch = c.getMultiKeyCQLStatement(tableName, featureNames, len(serializedEntityKeys)) - - job := BatchJob{ - ViewName: featureViewName, - TableName: tableName, - FeatureNames: featureNames, - EntityKeys: serializedEntityKeys, - CQLStatement: cqlForBatch, + results := make([][]FeatureData, len(entityKeys)) + for i := range results { + results[i] = make([]FeatureData, len(featureNames)) } - results, err := c.executeBatchV2(ctx, job, serializedEntityKeyToIndex, featureNamesToIdx) + batches := c.createBatches(serializedEntityKeys) - if err != nil { + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + + var prevBatchLength int + var cqlStatement string + + for i, batch := range batches { + var cqlForBatch string + if i == 0 || len(batch) != prevBatchLength { + cqlForBatch = c.getMultiKeyCQLStatement(tableName, len(featureNames), len(batch)) + prevBatchLength = len(batch) + cqlStatement = cqlForBatch + } else { + cqlForBatch = cqlStatement + } + + job := BatchJob{ + ViewName: featureViewName, + TableName: tableName, + FeatureNames: featureNames, + EntityKeys: batch, + CQLStatement: cqlForBatch, + } + + g.Go(func() error { + batchResults, err := c.executeBatchV2(ctx, job, featureNamesToIdx) + if err != nil { + return err + } + + mu.Lock() + defer mu.Unlock() + for localIdx, key := range job.EntityKeys { + globalIdx := serializedEntityKeyToIndex[key.(string)] + for featIdx := range featureNames { + results[globalIdx][featIdx] = batchResults[localIdx][featIdx] + } + } + return nil + }) + } + + if err := g.Wait(); err != nil { return nil, err } + for i := range results { + for j, feat := range results[i] { + if feat.Value.Val == nil { + results[i][j] = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureNames[j], + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + } + } + } + return results, nil } func (c *CassandraOnlineStore) executeBatchV2( ctx context.Context, job BatchJob, - serializedEntityKeyToIndex map[string]int, featureNamesToIdx map[string]int, ) ([][]FeatureData, error) { + localKeyToIndex := make(map[string]int, len(job.EntityKeys)) + for i, key := range job.EntityKeys { + localKeyToIndex[key.(string)] = i + } + results := make([][]FeatureData, len(job.EntityKeys)) for i := range results { results[i] = make([]FeatureData, len(job.FeatureNames)) } - iter := c.session.Query(job.CQLStatement, job.EntityKeys...).WithContext(ctx).Iter() + queryParams := buildQueryParams(job.EntityKeys, job.FeatureNames) + iter := c.session.Query(job.CQLStatement, queryParams...).WithContext(ctx).Iter() defer iter.Close() scanner := iter.Scanner() @@ -519,9 +585,10 @@ func (c *CassandraOnlineStore) executeBatchV2( for _, serializedEntityKey := range job.EntityKeys { for _, featName := range job.FeatureNames { keyString := serializedEntityKey.(string) + localIdx := localKeyToIndex[keyString] if featureData, exists := batchFeatures[keyString][featName]; exists { - results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{ + results[localIdx][featureNamesToIdx[featName]] = FeatureData{ Reference: serving.FeatureReferenceV2{ FeatureViewName: featureData.Reference.FeatureViewName, FeatureName: featureData.Reference.FeatureName, @@ -532,8 +599,7 @@ func (c *CassandraOnlineStore) executeBatchV2( }, } } else { - // TODO: return not found status to differentiate between nulls and not found features - results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{ + results[localIdx][featureNamesToIdx[featName]] = FeatureData{ Reference: serving.FeatureReferenceV2{ FeatureViewName: job.ViewName, FeatureName: featName, @@ -592,7 +658,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ for i, batch := range batches { var cqlForBatch string if i == 0 || len(batch) != prevBatchLength { - cqlForBatch = c.getMultiKeyCQLStatement(tableName, currentFeatureNames, len(batch)) + cqlForBatch = c.getMultiKeyCQLStatement(tableName, len(currentFeatureNames), len(batch)) prevBatchLength = len(batch) cqlStatement = cqlForBatch } else { @@ -640,7 +706,8 @@ func (c *CassandraOnlineStore) executeBatch( results [][]FeatureData, featureNamesToIdx map[string]int, ) error { - iter := c.session.Query(job.CQLStatement, job.EntityKeys...).WithContext(ctx).Iter() + queryParams := buildQueryParams(job.EntityKeys, job.FeatureNames) + iter := c.session.Query(job.CQLStatement, queryParams...).WithContext(ctx).Iter() defer iter.Close() scanner := iter.Scanner() @@ -1038,3 +1105,4 @@ func (c *CassandraOnlineStore) GetDataModelType() OnlineStoreDataModel { func (c *CassandraOnlineStore) GetReadBatchSize() int { return c.KeyBatchSize } +