diff --git a/.features/pending/sql-memoization-cache.md b/.features/pending/sql-memoization-cache.md new file mode 100644 index 000000000000..92b8fa3213c4 --- /dev/null +++ b/.features/pending/sql-memoization-cache.md @@ -0,0 +1,13 @@ +Description: Add SQL database-backed memoization cache as an alternative to ConfigMaps. +Authors: [droctothorpe](https://github.com/droctothorpe) +Component: General +Issues: 15952 +PRs: 15938 + +Memoization can now store cache entries in a PostgreSQL or MySQL database instead of Kubernetes ConfigMaps. +The SQL backend removes the 1 MB ConfigMap size limit and persists cache entries across cluster restarts. +ConfigMaps remain the default; opt in by adding a `memoization` block to the `workflow-controller-configmap`. +SQL-backed entries are stored in the configured table, which defaults to `cache_entries`. +Each cache entry computes an `expires_at` timestamp at save time from the template's `maxAge` field (default: 30 days). +The default max age can be overridden via the `DEFAULT_MAX_AGE` environment variable on the controller when SQL memoization is enabled. +A periodic garbage collector deletes expired entries whose `expires_at` has elapsed. diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index f0493ee0560f..14914260b8bb 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -251,6 +251,12 @@ jobs: - test: test-corefunctional profile: minimal use-api: false + - test: test-sqldbmemoize + profile: mysql + use-api: false + - test: test-sqldbmemoize + profile: postgres + use-api: false - test: test-functional profile: minimal use-api: false diff --git a/.spelling b/.spelling index 06a13734e8a9..2c4577cb372e 100644 --- a/.spelling +++ b/.spelling @@ -151,6 +151,7 @@ args async auth backend +backends backoff backport backported @@ -175,6 +176,7 @@ entrypoint enum env errored +expires_at expr fibonacci filename diff --git a/config/config.go b/config/config.go index 37b95844b5a6..f0af01e154e7 100644 --- a/config/config.go +++ b/config/config.go @@ -121,6 +121,10 @@ type Config struct { // Synchronization via databases config Synchronization *SyncConfig `json:"synchronization,omitempty"` + // Memoization configures memoization cache storage. When set, cache entries are stored in a + // database instead of ConfigMaps. ConfigMap-based caching remains the default when omitted. + Memoization *MemoizationConfig `json:"memoization,omitempty"` + // ArtifactDrivers lists artifact driver plugins we can use ArtifactDrivers []ArtifactDriver `json:"artifactDrivers,omitempty"` @@ -353,6 +357,17 @@ type SyncConfig struct { SemaphoreLimitCacheSeconds *int64 `json:"semaphoreLimitCacheSeconds,omitempty"` } +// MemoizationConfig contains memoization cache configuration for database-backed storage. +// When configured, cache entries are stored in the specified database table instead of ConfigMaps. +type MemoizationConfig struct { + DBConfig + // TableName is the name of the table to use for memoization cache entries. + // Defaults to "cache_entries" if not set. + TableName string `json:"tableName,omitempty"` + // SkipMigration skips automatic database migration on startup. + SkipMigration bool `json:"skipMigration,omitempty"` +} + // ConnectionPool contains database connection pool settings type ConnectionPool struct { // MaxIdleConns sets the maximum number of idle connections in the pool diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 1fe05d1b311f..dbca7a9c2bb5 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -27,6 +27,7 @@ This document outlines environment variables that can be used to customize behav | `CACHE_GC_PERIOD` | `time.Duration` | `0s` | How often to perform memoization cache GC, which is disabled by default and can be enabled by providing a non-zero duration. | | `CACHE_GC_AFTER_NOT_HIT_DURATION` | `time.Duration` | `30s` | When a memoization cache has not been hit after this duration, it will be deleted. | | `CRON_SYNC_PERIOD` | `time.Duration` | `10s` | How often to sync cron workflows. | +| `DEFAULT_MAX_AGE` | `string` | `""` (30 days) | Default TTL for SQL-backed memoization cache entries when `memoize.maxAge` is not set on the template. Accepts a Go duration string (e.g. `720h`) or an integer number of seconds. If unset, entries expire after 30 days. This does not affect ConfigMap-backed memoization. | | `DEFAULT_REQUEUE_TIME` | `time.Duration` | `10s` | The re-queue time for the rate limiter of the workflow queue. | | `DISABLE_MAX_RECURSION` | `bool` | `false` | Set to true to disable the recursion preventer, which will stop a workflow running which has called into a child template 100 times | | `EXPRESSION_TEMPLATES` | `bool` | `true` | Escape hatch to disable expression templates. | @@ -40,6 +41,7 @@ This document outlines environment variables that can be used to customize behav | `LEADER_ELECTION_RENEW_DEADLINE` | `time.Duration` | `10s` | The duration that the acting master will retry refreshing leadership before giving up. | | `LEADER_ELECTION_RETRY_PERIOD` | `time.Duration` | `5s` | The duration that the leader election clients should wait between tries of actions. | | `MAX_OPERATION_TIME` | `time.Duration` | `30s` | The maximum time a workflow operation is allowed to run for before re-queuing the workflow onto the work queue. | +| `MEMO_CACHE_GC_PERIOD` | `time.Duration` | `24h` | How often the SQL-backed memoization cache garbage collector runs to prune entries that have exceeded their TTL. | | `OFFLOAD_NODE_STATUS_TTL` | `time.Duration` | `5m` | The TTL to delete the offloaded node status. Currently only used for testing. | | `OPERATION_DURATION_METRIC_BUCKET_COUNT` | `int` | `6` | The number of buckets to collect the metric for the operation duration. | | `POD_NAMES` | `string` | `v2` | Whether to have pod names contain the template name (v2) or be the node id (v1) - should be set the same for Argo Server. | diff --git a/docs/memoization.md b/docs/memoization.md index 02d6ee4dbb92..7c0c16f31f6c 100644 --- a/docs/memoization.md +++ b/docs/memoization.md @@ -14,15 +14,72 @@ If you are using workflows prior to version 3.5 you should look at the [work avo In version 3.5 or later all steps can be memoized, whether or not they have outputs. -## Cache Method +## Cache Backends -Currently, the cached data is stored in config-maps. +Argo Workflows supports two backends for storing memoization cache entries: + +### ConfigMap (default) + +By default, cached data is stored in Kubernetes ConfigMaps. This allows you to easily manipulate cache entries manually through `kubectl` and the Kubernetes API without having to go through Argo. -All cache config-maps must have the label `workflows.argoproj.io/configmap-type: Cache` to be used as a cache. This prevents accidental access to other important config-maps in the system +All cache ConfigMaps must have the label `workflows.argoproj.io/configmap-type: Cache` to be used as a cache. This prevents accidental access to other important ConfigMaps in the system. + +### SQL Database + +> v4.0 and after + +Alternatively, cache entries can be stored in a PostgreSQL or MySQL database. This is recommended for production use — it has no size limits, supports long-term persistence, and includes automatic garbage collection. + +To enable SQL-backed memoization, add a `memoization` block to the `workflow-controller-configmap`: + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: workflow-controller-configmap + namespace: argo +data: + memoization: | + tableName: cache_entries + postgresql: + host: postgres + port: 5432 + database: postgres + userNameSecret: + name: argo-postgres-config + key: username + passwordSecret: + name: argo-postgres-config + key: password +``` + +SQL-backed memoization stores entries in the configured table. Set `memoization.tableName` to override the default table name; if omitted, it defaults to `cache_entries`. +The database connection settings remain under `postgresql` or `mysql`. + +Each cache entry stores its expiry time when it is written, derived from the template's `maxAge` field. If `maxAge` is not specified on the template, it defaults to 30 days (2592000 seconds). This default can be overridden by setting the `DEFAULT_MAX_AGE` environment variable on the workflow controller for SQL-backed memoization (accepts Go duration strings like `720h` or integer seconds like `2592000`). + +The garbage collector periodically deletes expired entries. The GC period defaults to 24 hours and can be configured via the `MEMO_CACHE_GC_PERIOD` environment variable. + +MySQL is also supported: + +```yaml + memoization: | + tableName: cache_entries + mysql: + host: mysql + port: 3306 + database: argo + userNameSecret: + name: argo-mysql-config + key: username + passwordSecret: + name: argo-mysql-config + key: password +``` ## Using Memoization -Memoization is set at the template level. You must specify a `key`, which can be static strings but more often depend on inputs. +Memoization is configured at the template level via the `memoize` field. You must specify a `key`, which can be static strings but more often depend on inputs. You must also specify a name for the `config-map` cache. Optionally you can set a `maxAge` in seconds or hours (e.g. `180s`, `24h`) to define how long should it be considered valid. If an entry is older than the `maxAge`, it will be ignored. @@ -43,18 +100,27 @@ spec: name: print-message-cache ``` +### Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `key` | Yes | The cache lookup key. | +| `cache` | Yes | Specifies the cache storage. When using the ConfigMap backend, a ConfigMap is created. When using the SQL backend, `cache.configMap.name` acts as a logical group name in the database — no ConfigMap is created. | +| `maxAge` | No | Maximum age of a cache entry (e.g. `"180s"`, `"24h"`). Entries older than this are treated as misses at lookup time. When omitted for SQL-backed memoization, it defaults to 30 days or the controller's `DEFAULT_MAX_AGE` setting. | + [Find a simple example for memoization here](https://github.com/argoproj/argo-workflows/blob/main/examples/memoize-simple.yaml). !!! Note - In order to use memoization it is necessary to add the verbs `create` and `update` to the `configmaps` resource for the appropriate (cluster) roles. In the case of a cluster install the `argo-cluster-role` cluster role should be updated, whilst for a namespace install the `argo-role` role should be updated. + To use memoization with the ConfigMap backend, add the verbs `create` and `update` to the `configmaps` resource for the appropriate (cluster) roles. For a cluster install, update the `argo-cluster-role` cluster role; for a namespace install, update the `argo-role` role. This is not required when using the SQL database backend. ## FAQ 1. If you see errors like `error creating cache entry: ConfigMap \"reuse-task\" is invalid: []: Too long: must have at most 1048576 characters`, this is due to [the 1MB limit placed on the size of `ConfigMap`](https://github.com/kubernetes/kubernetes/issues/19781). Here are a couple of ways that might help resolve this: - * Delete the existing `ConfigMap` cache or switch to use a different cache. - * Reduce the size of the output parameters for the nodes that are being memoized. - * Split your cache into different memoization keys and cache names so that each cache entry is small. + - Delete the existing `ConfigMap` cache or switch to use a different cache. + - Reduce the size of the output parameters for the nodes that are being memoized. + - Split your cache into different memoization keys and cache names so that each cache entry is small. + - Switch to the SQL database backend which has no size limit. 1. My step isn't getting memoized, why not? If you are running workflows <3.5 ensure that you have specified at least one output on the step. diff --git a/docs/workflow-controller-configmap.md b/docs/workflow-controller-configmap.md index 4fe1a47456ca..e56ac256abdb 100644 --- a/docs/workflow-controller-configmap.md +++ b/docs/workflow-controller-configmap.md @@ -96,6 +96,7 @@ Config contains the root of the configuration settings for the workflow controll | `NavColor` | `string` | NavColor is an ui navigation bar background color | | `SSO` | [`SSOConfig`](#ssoconfig) | SSO in settings for single-sign on | | `Synchronization` | [`SyncConfig`](#syncconfig) | Synchronization via databases config | +| `Memoization` | [`MemoizationConfig`](#memoizationconfig) | Memoization configures memoization cache storage. When set, cache entries are stored in a database instead of ConfigMaps. ConfigMap-based caching remains the default when omitted. | | `ArtifactDrivers` | `Array<`[`ArtifactDriver`](#artifactdriver)`>` | ArtifactDrivers lists artifact driver plugins we can use | | `FailedPodRestart` | [`FailedPodRestartConfig`](#failedpodrestartconfig) | FailedPodRestart configures automatic restart of pods that fail before entering Running state (e.g., due to Eviction, DiskPressure, Preemption). This allows recovery from transient infrastructure issues without requiring a retryStrategy on templates. | @@ -369,6 +370,21 @@ SyncConfig contains synchronization configuration for database locks (semaphores | `InactiveControllerSeconds` | `int` | InactiveControllerSeconds specifies when to consider a controller dead, if not set, the default value is 300 seconds | | `SemaphoreLimitCacheSeconds` | `int64` | SemaphoreLimitCacheSeconds specifies the duration in seconds before the workflow controller will re-fetch the limit for a semaphore from its associated data source. Defaults to 0 seconds (re-fetch every time the semaphore is checked). | +## MemoizationConfig + +MemoizationConfig contains memoization cache configuration for database-backed storage. When configured, cache entries are stored in the specified database table instead of ConfigMaps. + +### Fields + +| Field Name | Field Type | Description | +|---------------------|-------------------------------------------|------------------------------------------------------------------------------------------------------------------| +| `PostgreSQL` | [`PostgreSQLConfig`](#postgresqlconfig) | PostgreSQL configuration for PostgreSQL database, don't use MySQL at the same time | +| `MySQL` | [`MySQLConfig`](#mysqlconfig) | MySQL configuration for MySQL database, don't use PostgreSQL at the same time | +| `ConnectionPool` | [`ConnectionPool`](#connectionpool) | Pooled connection settings for all types of database connections | +| `DBReconnectConfig` | [`DBReconnectConfig`](#dbreconnectconfig) | DBReconnectConfig are configuration options for database retries and reconnections | +| `TableName` | `string` | TableName is the name of the table to use for memoization cache entries. Defaults to "cache_entries" if not set. | +| `SkipMigration` | `bool` | SkipMigration skips automatic database migration on startup. | + ## ArtifactDriver ArtifactDriver is a plugin for an artifact driver diff --git a/manifests/components/mysql/overlays/workflow-controller-configmap.yaml b/manifests/components/mysql/overlays/workflow-controller-configmap.yaml index 83b9ef02f388..9c4086616a3e 100644 --- a/manifests/components/mysql/overlays/workflow-controller-configmap.yaml +++ b/manifests/components/mysql/overlays/workflow-controller-configmap.yaml @@ -44,6 +44,18 @@ data: passwordSecret: name: argo-mysql-config key: password + memoization: | + tableName: cache_entries + mysql: + host: mysql + port: 3306 + database: argo + userNameSecret: + name: argo-mysql-config + key: username + passwordSecret: + name: argo-mysql-config + key: password retentionPolicy: | completed: 10 failed: 3 diff --git a/manifests/components/postgres/overlays/workflow-controller-configmap.yaml b/manifests/components/postgres/overlays/workflow-controller-configmap.yaml index a903daff5a7e..39572e2a9dda 100644 --- a/manifests/components/postgres/overlays/workflow-controller-configmap.yaml +++ b/manifests/components/postgres/overlays/workflow-controller-configmap.yaml @@ -43,6 +43,18 @@ data: passwordSecret: name: argo-postgres-config key: password + memoization: | + tableName: cache_entries + postgresql: + host: postgres + port: 5432 + database: postgres + userNameSecret: + name: argo-postgres-config + key: username + passwordSecret: + name: argo-postgres-config + key: password retentionPolicy: | completed: 10 failed: 3 diff --git a/manifests/quick-start-mysql.yaml b/manifests/quick-start-mysql.yaml index 97539d2bf7f1..36d909e19c8f 100644 --- a/manifests/quick-start-mysql.yaml +++ b/manifests/quick-start-mysql.yaml @@ -183471,6 +183471,18 @@ data: - name: Completed Workflows scope: workflow-list url: http://workflows?label=workflows.argoproj.io/completed=true + memoization: | + tableName: cache_entries + mysql: + host: mysql + port: 3306 + database: argo + userNameSecret: + name: argo-mysql-config + key: username + passwordSecret: + name: argo-mysql-config + key: password metricsConfig: | enabled: true path: /metrics diff --git a/manifests/quick-start-postgres.yaml b/manifests/quick-start-postgres.yaml index cd87bd93d17a..a7f8195128e4 100644 --- a/manifests/quick-start-postgres.yaml +++ b/manifests/quick-start-postgres.yaml @@ -183471,6 +183471,18 @@ data: - name: Completed Workflows scope: workflow-list url: http://workflows?label=workflows.argoproj.io/completed=true + memoization: | + tableName: cache_entries + postgresql: + host: postgres + port: 5432 + database: postgres + userNameSecret: + name: argo-postgres-config + key: username + passwordSecret: + name: argo-postgres-config + key: password metricsConfig: | enabled: true path: /metrics diff --git a/test/e2e/sqldb_memoize_test.go b/test/e2e/sqldb_memoize_test.go new file mode 100644 index 000000000000..0a8b8cc16026 --- /dev/null +++ b/test/e2e/sqldb_memoize_test.go @@ -0,0 +1,148 @@ +//go:build sqldbmemoize + +package e2e + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v4/test/e2e/fixtures" + "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" + "github.com/argoproj/argo-workflows/v4/util/sqldb" +) + +type SQLDBMemoizeSuite struct { + fixtures.E2ESuite +} + +// memoWorkflow builds a workflow spec with a unique cache key per test run to avoid +// stale cache hits from previous runs. +func memoWorkflow(cacheKey string) string { + return fmt.Sprintf(`apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: sqldb-memoize- +spec: + entrypoint: hello + templates: + - name: hello + steps: + - - name: run1 + template: memoized + arguments: + parameters: [{name: message, value: "%s"}] + - - name: run2 + template: memoized + arguments: + parameters: [{name: message, value: "%s"}] + - name: memoized + inputs: + parameters: + - name: message + memoize: + key: "{{inputs.parameters.message}}" + maxAge: "10m" + cache: + configMap: + name: sqldb-memo-cache + container: + image: argoproj/argosay:v2 + command: [echo] + args: ["{{inputs.parameters.message}}"] +`, cacheKey, cacheKey) +} + +func (s *SQLDBMemoizeSuite) TestSQLDBMemoize() { + s.Require().NotNil(s.Config.Memoization, "memoization DB must be configured for SQL cache test") + + ctx := logging.TestContext(s.T().Context()) + + // Use a unique key so each test run starts with a cold cache. + cacheKey := fmt.Sprintf("hello-sqldb-%d", time.Now().UnixNano()) + + // Submit the workflow and wait for it to succeed. + s.Given(). + Workflow(memoWorkflow(cacheKey)). + When(). + SubmitWorkflow(). + WaitForWorkflow(fixtures.ToBeSucceeded). + Then(). + ExpectWorkflow(func(t *testing.T, _ *metav1.ObjectMeta, status *wfv1.WorkflowStatus) { + memoHit := false + memoSaved := false + for _, node := range status.Nodes { + if node.MemoizationStatus == nil { + continue + } + if node.MemoizationStatus.Hit { + memoHit = true + } else { + memoSaved = true + } + } + assert.True(t, memoSaved, "expected at least one node to save to the cache") + assert.True(t, memoHit, "expected at least one node to hit the cache") + }) + + // Also verify the entry landed in the configured SQL database, not a ConfigMap. + s.assertDBCacheEntry(ctx, cacheKey) + s.assertNoConfigMap(ctx, "sqldb-memo-cache") +} + +// assertDBCacheEntry checks the configured memoization table directly. +func (s *SQLDBMemoizeSuite) assertDBCacheEntry(ctx context.Context, key string) { + memoCfg := s.Config.Memoization + // E2E tests connect to the configured SQL backend via a port-forward on localhost. + cfg := *memoCfg + if cfg.PostgreSQL != nil { + pg := *cfg.PostgreSQL + pg.Host = "localhost" + cfg.PostgreSQL = &pg + } + if cfg.MySQL != nil { + my := *cfg.MySQL + my.Host = "localhost" + cfg.MySQL = &my + } + + session, _, err := sqldb.CreateDBSession(ctx, s.KubeClient, fixtures.Namespace, cfg.DBConfig) + s.Require().NoError(err, "could not connect to memoization DB") + defer session.Close() + + tableName := memodb.TableName(&cfg) + + var count int + row, err := session.SQL(). + QueryRow(fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE namespace = ? AND cache_name = ? AND cache_key = ?`, tableName), + fixtures.Namespace, "sqldb-memo-cache", key) + s.Require().NoError(err) + s.Require().NoError(row.Scan(&count)) + s.Equal(1, count, "expected exactly one cache entry in the database for key %q", key) + + // Also verify outputs are stored as valid JSON. + var outputs string + row, err = session.SQL(). + QueryRow(fmt.Sprintf(`SELECT outputs FROM %s WHERE namespace = ? AND cache_name = ? AND cache_key = ?`, tableName), + fixtures.Namespace, "sqldb-memo-cache", key) + s.Require().NoError(err) + s.Require().NoError(row.Scan(&outputs)) + s.NotEmpty(outputs) +} + +// assertNoConfigMap verifies the controller did NOT fall back to creating a ConfigMap cache. +func (s *SQLDBMemoizeSuite) assertNoConfigMap(ctx context.Context, name string) { + _, err := s.KubeClient.CoreV1().ConfigMaps(fixtures.Namespace).Get(ctx, name, metav1.GetOptions{}) + s.Error(err, "ConfigMap %q should not exist when SQL memoization is configured", name) +} + +func TestSQLDBMemoizeSuite(t *testing.T) { + suite.Run(t, new(SQLDBMemoizeSuite)) +} diff --git a/util/file/watch.go b/util/file/watch.go index a68dfbb7c4aa..6238c746cd56 100644 --- a/util/file/watch.go +++ b/util/file/watch.go @@ -113,7 +113,7 @@ func watchFilePoll(ctx context.Context, path string, onChange func()) error { } return err } - if last == nil || fi.ModTime() != last.ModTime() || fi.Size() != last.Size() { + if last == nil || !fi.ModTime().Equal(last.ModTime()) || fi.Size() != last.Size() { onChange() last = fi } diff --git a/util/memo/db/config.go b/util/memo/db/config.go new file mode 100644 index 000000000000..8fb58bc0aac1 --- /dev/null +++ b/util/memo/db/config.go @@ -0,0 +1,119 @@ +package db + +import ( + "context" + "fmt" + "hash/fnv" + "regexp" + + "k8s.io/client-go/kubernetes" + + "github.com/argoproj/argo-workflows/v4/config" + "github.com/argoproj/argo-workflows/v4/util/logging" + "github.com/argoproj/argo-workflows/v4/util/sqldb" +) + +const ( + defaultTableName = "cache_entries" + versionTable = "memoization_schema_history" +) + +var validTableName = regexp.MustCompile(`^[A-Za-z0-9_]+$`) + +// Config holds resolved configuration for database-backed memoization. +type Config struct { + TableName string + SkipMigration bool +} + +func TableName(cfg *config.MemoizationConfig) string { + if cfg == nil || cfg.TableName == "" { + return defaultTableName + } + return cfg.TableName +} + +func validateTableName(tableName string) error { + if !validTableName.MatchString(tableName) { + return fmt.Errorf("invalid table name %q: must match [A-Za-z0-9_]+", tableName) + } + return nil +} + +// memoizationVersionTableName returns the schema history table name for the given memoization +// cache table. The default table name uses a fixed well-known name; custom table names get a +// deterministic hash suffix to avoid collisions. +func memoizationVersionTableName(tableName string) string { + if tableName == defaultTableName { + return versionTable + } + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(tableName)) + return fmt.Sprintf("memoization_schema_history_%x", hasher.Sum64()) +} + +// memoizationExpiresAtIndexName returns the name of the expires_at index for the given +// memoization cache table, using a deterministic hash suffix to avoid collisions across +// multiple cache tables. +func memoizationExpiresAtIndexName(tableName string) string { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(tableName)) + return fmt.Sprintf("memoization_expires_at_%x", hasher.Sum64()) +} + +// ConfigFromConfig converts a controller MemoizationConfig (with DB credentials, connection +// settings, etc.) into the smaller Config struct used by the migration and query layers. +// Returns sensible defaults when cfg is nil. +func ConfigFromConfig(cfg *config.MemoizationConfig) Config { + if cfg == nil { + return Config{TableName: defaultTableName} + } + return Config{ + TableName: TableName(cfg), + SkipMigration: cfg.SkipMigration, + } +} + +// SessionProxyFromConfig creates a SessionProxy from a MemoizationConfig, returning nil and logging +// an error if the connection cannot be established. Callers that receive nil should decide how to +// degrade memoization without crashing the controller. +func SessionProxyFromConfig(ctx context.Context, kubectlConfig kubernetes.Interface, namespace string, cfg *config.MemoizationConfig) *sqldb.SessionProxy { + if cfg == nil { + return nil + } + sessionProxy, err := sqldb.NewSessionProxy(ctx, sqldb.SessionProxyConfig{ + KubectlConfig: kubectlConfig, + Namespace: namespace, + DBConfig: cfg.DBConfig, + }) + if err != nil { + log := logging.RequireLoggerFromContext(ctx) + log.WithError(err).Error(ctx, "unable to create memoization database connection") + return nil + } + sqldb.ConfigureDBSession(sessionProxy.Session(), cfg.ConnectionPool) + return sessionProxy +} + +// Migrate runs database migrations for the memoization cache table. It is a no-op when +// cfg.SkipMigration is true. Returns an error if migration fails; callers should decide how to +// degrade memoization without crashing the controller. +func Migrate(ctx context.Context, sessionProxy *sqldb.SessionProxy, cfg Config) error { + if sessionProxy == nil { + return nil + } + logger := logging.RequireLoggerFromContext(ctx) + if err := validateTableName(cfg.TableName); err != nil { + return err + } + if cfg.SkipMigration { + logger.Info(ctx, "Memoization db migration skipped") + return nil + } + logger.Info(ctx, "Running memoization db migration") + if err := migrate(ctx, sessionProxy.Session(), sessionProxy.DBType(), cfg.TableName); err != nil { + return err + } + logger.Info(ctx, "Memoization db migration complete") + return nil +} diff --git a/util/memo/db/config_test.go b/util/memo/db/config_test.go new file mode 100644 index 000000000000..3dd07a638dd1 --- /dev/null +++ b/util/memo/db/config_test.go @@ -0,0 +1,23 @@ +package db_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/argoproj/argo-workflows/v4/config" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" +) + +func TestTableNameDefaultsAndOverrides(t *testing.T) { + assert.Equal(t, "cache_entries", memodb.TableName(nil)) + assert.Equal(t, "cache_entries", memodb.TableName(&config.MemoizationConfig{})) + assert.Equal(t, "custom_cache_entries", memodb.TableName(&config.MemoizationConfig{TableName: "custom_cache_entries"})) +} + +func TestNewQueriesRejectsInvalidTableName(t *testing.T) { + queries, err := memodb.NewQueries("invalid-table-name", nil) + require.Error(t, err) + assert.Nil(t, queries) +} diff --git a/util/memo/db/migrate.go b/util/memo/db/migrate.go new file mode 100644 index 000000000000..45ff46fb7443 --- /dev/null +++ b/util/memo/db/migrate.go @@ -0,0 +1,43 @@ +package db + +import ( + "context" + + "github.com/upper/db/v4" + + "github.com/argoproj/argo-workflows/v4/util/sqldb" +) + +func migrate(ctx context.Context, session db.Session, dbType sqldb.DBType, tableName string) error { + if err := validateTableName(tableName); err != nil { + return err + } + return sqldb.Migrate(ctx, session, dbType, memoizationVersionTableName(tableName), []sqldb.Change{ + // MySQL: use LONGTEXT for outputs (TEXT is 64KB). + // Postgres: use text for outputs (no size limit). + // Varchar sizes chosen to keep composite PK within InnoDB's 3072-byte limit with utf8mb4: + // (64 + 128 + 256) * 4 = 1792 bytes. + sqldb.ByType(dbType, sqldb.TypedChanges{ + sqldb.Postgres: sqldb.AnsiSQLChange(`create table if not exists ` + tableName + ` ( + namespace varchar(64) not null, + cache_name varchar(128) not null, + cache_key varchar(256) not null, + node_id text not null, + outputs text not null, + created_at timestamp not null, + expires_at timestamp not null, + primary key (namespace, cache_name, cache_key) +)`), + sqldb.MySQL: sqldb.AnsiSQLChange("create table if not exists " + tableName + " (" + + "namespace varchar(64) not null, " + + "cache_name varchar(128) not null, " + + "cache_key varchar(256) not null, " + + "node_id text not null, " + + "outputs longtext not null, " + + "created_at timestamp not null, " + + "expires_at timestamp not null, " + + "primary key (namespace, cache_name, cache_key))"), + }), + sqldb.AnsiSQLChange(`create index ` + memoizationExpiresAtIndexName(tableName) + ` on ` + tableName + ` (expires_at)`), + }) +} diff --git a/util/memo/db/queries.go b/util/memo/db/queries.go new file mode 100644 index 000000000000..cda000b22a7e --- /dev/null +++ b/util/memo/db/queries.go @@ -0,0 +1,181 @@ +package db + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/upper/db/v4" + + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v4/util/sqldb" +) + +const ( + colNamespace = "namespace" + colCacheName = "cache_name" + colCacheKey = "cache_key" + colExpiresAt = "expires_at" +) + +// CacheRecord is the database row for a single memoization cache entry. +type CacheRecord struct { + Namespace string `db:"namespace"` + CacheName string `db:"cache_name"` + CacheKey string `db:"cache_key"` + NodeID string `db:"node_id"` + Outputs string `db:"outputs"` // JSON + CreatedAt time.Time `db:"created_at"` + ExpiresAt time.Time `db:"expires_at"` +} + +// MemoizationDB is the interface for database-backed memoization cache operations. +type MemoizationDB interface { + Load(ctx context.Context, namespace, cacheName, cacheKey string) (*CacheRecord, error) + Save(ctx context.Context, namespace, cacheName, cacheKey, nodeID string, outputs *wfv1.Outputs, maxAgeSeconds int64) error + Prune(ctx context.Context) (int64, error) + IsEnabled() bool +} + +// NullMemoizationDB is a no-op implementation used when database memoization is disabled. +var NullMemoizationDB MemoizationDB = &nullMemoizationDB{} + +type nullMemoizationDB struct{} + +func (n *nullMemoizationDB) Load(context.Context, string, string, string) (*CacheRecord, error) { + return nil, nil +} + +func (n *nullMemoizationDB) Save(context.Context, string, string, string, string, *wfv1.Outputs, int64) error { + return nil +} + +func (n *nullMemoizationDB) Prune(context.Context) (int64, error) { + return 0, nil +} + +func (n *nullMemoizationDB) IsEnabled() bool { + return false +} + +var _ MemoizationDB = &queries{} + +// queries provides database operations for the memoization cache table. +type queries struct { + tableName string + sessionProxy *sqldb.SessionProxy +} + +func NewQueries(tableName string, sessionProxy *sqldb.SessionProxy) (MemoizationDB, error) { + if err := validateTableName(tableName); err != nil { + return nil, err + } + return &queries{tableName: tableName, sessionProxy: sessionProxy}, nil +} + +func (q *queries) IsEnabled() bool { + return true +} + +func cacheRecordCond(record *CacheRecord) db.Cond { + return db.Cond{ + colNamespace: record.Namespace, + colCacheName: record.CacheName, + colCacheKey: record.CacheKey, + } +} + +func cacheRecordUpdates(record *CacheRecord) map[string]any { + return map[string]any{ + "node_id": record.NodeID, + "outputs": record.Outputs, + "created_at": record.CreatedAt, + "expires_at": record.ExpiresAt, + } +} + +func isDuplicateKeyError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "Duplicate entry") +} + +func saveRecord(sess db.Session, tableName string, record *CacheRecord) error { + collection := sess.Collection(tableName) + _, err := collection.Insert(record) + if err == nil { + return nil + } + if !isDuplicateKeyError(err) { + return err + } + return collection.Find(cacheRecordCond(record)).Update(cacheRecordUpdates(record)) +} + +// Load retrieves the outputs for the given cache key. +// Returns nil when the entry does not exist or has expired. +func (q *queries) Load(ctx context.Context, namespace, cacheName, cacheKey string) (*CacheRecord, error) { + var r CacheRecord + now := time.Now().UTC() + err := q.sessionProxy.With(ctx, func(sess db.Session) error { + return sess.SQL(). + SelectFrom(q.tableName). + Where(db.Cond{colNamespace: namespace}). + And(db.Cond{colCacheName: cacheName}). + And(db.Cond{colCacheKey: cacheKey}). + And(db.Cond{colExpiresAt + " >": now}). + One(&r) + }) + if errors.Is(err, db.ErrNoMoreRows) { + return nil, nil + } + if err != nil { + return nil, err + } + return &r, nil +} + +// Prune deletes cache entries whose expires_at has elapsed. It is called +// periodically by the controller to bound the size of the configured memoization cache table. +func (q *queries) Prune(ctx context.Context) (int64, error) { + now := time.Now().UTC() + var n int64 + err := q.sessionProxy.With(ctx, func(sess db.Session) error { + result, err := sess.SQL(). + DeleteFrom(q.tableName). + Where(db.Cond{colExpiresAt + " <": now}). + Exec() + if err != nil { + return err + } + n, err = result.RowsAffected() + return err + }) + return n, err +} + +func (q *queries) Save( + ctx context.Context, namespace, cacheName, cacheKey, nodeID string, outputs *wfv1.Outputs, maxAgeSeconds int64) error { + outputsJSON, err := json.Marshal(outputs) + if err != nil { + return fmt.Errorf("unable to marshal memoization outputs: %w", err) + } + now := time.Now().UTC() + expiresAt := now.Add(time.Duration(maxAgeSeconds) * time.Second) + record := &CacheRecord{ + Namespace: namespace, + CacheName: cacheName, + CacheKey: cacheKey, + NodeID: nodeID, + Outputs: string(outputsJSON), + CreatedAt: now, + ExpiresAt: expiresAt, + } + return q.sessionProxy.With(ctx, func(sess db.Session) error { + return saveRecord(sess, q.tableName, record) + }) +} diff --git a/util/memo/db/queries_internal_test.go b/util/memo/db/queries_internal_test.go new file mode 100644 index 000000000000..e8aa4050c21f --- /dev/null +++ b/util/memo/db/queries_internal_test.go @@ -0,0 +1,50 @@ +package db + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + upperdb "github.com/upper/db/v4" +) + +func TestCacheRecordCond(t *testing.T) { + record := &CacheRecord{ + Namespace: "my-ns", + CacheName: "my-cache", + CacheKey: "my-key", + NodeID: "ignored", + } + + assert.Equal(t, upperdb.Cond{ + colNamespace: "my-ns", + colCacheName: "my-cache", + colCacheKey: "my-key", + }, cacheRecordCond(record)) +} + +func TestCacheRecordUpdates(t *testing.T) { + now := time.Unix(100, 0).UTC() + expiresAt := time.Unix(200, 0).UTC() + record := &CacheRecord{ + NodeID: "node-1", + Outputs: `{"result":"ok"}`, + CreatedAt: now, + ExpiresAt: expiresAt, + } + + assert.Equal(t, map[string]any{ + "node_id": "node-1", + "outputs": `{"result":"ok"}`, + "created_at": now, + "expires_at": expiresAt, + }, cacheRecordUpdates(record)) +} + +func TestIsDuplicateKeyError(t *testing.T) { + assert.True(t, isDuplicateKeyError(errors.New("pq: duplicate key value violates unique constraint"))) + assert.True(t, isDuplicateKeyError(errors.New("Error 1062: Duplicate entry 'x' for key 'PRIMARY'"))) + assert.False(t, isDuplicateKeyError(errors.New("some other error"))) + assert.False(t, isDuplicateKeyError(nil)) +} diff --git a/util/memo/db/queries_test.go b/util/memo/db/queries_test.go new file mode 100644 index 000000000000..496f90ef24b0 --- /dev/null +++ b/util/memo/db/queries_test.go @@ -0,0 +1,343 @@ +//go:build !windows + +package db_test + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + testcontainers "github.com/testcontainers/testcontainers-go" + testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" + testpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/argoproj/argo-workflows/v4/config" + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" + "github.com/argoproj/argo-workflows/v4/util/sqldb" +) + +const ( + testDBName = "memotest" + testDBUser = "user" + testDBPassword = "pass" + testNamespace = "default" + testCacheName = "my-cache" +) + +var testTableName = memodb.TableName(nil) + +// setupPostgres starts a throwaway Postgres container and returns a migrated SessionProxy. +func setupPostgres(ctx context.Context, t *testing.T) *sqldb.SessionProxy { + t.Helper() + pg, err := testpostgres.Run(ctx, + "postgres:17.4-alpine", + testpostgres.WithDatabase(testDBName), + testpostgres.WithUsername(testDBUser), + testpostgres.WithPassword(testDBPassword), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(30*time.Second), + ), + ) + require.NoError(t, err) + t.Cleanup(func() { + if termErr := testcontainers.TerminateContainer(pg); termErr != nil { + t.Logf("failed to terminate postgres container: %s", termErr) + } + }) + + host, err := pg.Host(ctx) + require.NoError(t, err) + portStr, err := pg.MappedPort(ctx, "5432/tcp") + require.NoError(t, err) + port, err := strconv.Atoi(portStr.Port()) + require.NoError(t, err) + + dbCfg := config.DBConfig{ + PostgreSQL: &config.PostgreSQLConfig{ + DatabaseConfig: config.DatabaseConfig{ + Host: host, + Port: port, + Database: testDBName, + }, + }, + } + sp, err := sqldb.NewSessionProxy(ctx, sqldb.SessionProxyConfig{ + DBConfig: dbCfg, + Username: testDBUser, + Password: testDBPassword, + MaxRetries: 5, + BaseDelay: 200 * time.Millisecond, + MaxDelay: 10 * time.Second, + }) + require.NoError(t, err) + t.Cleanup(func() { _ = sp.Close() }) + + memoCfg := &config.MemoizationConfig{ + DBConfig: dbCfg, + } + require.NoError(t, memodb.Migrate(ctx, sp, memodb.ConfigFromConfig(memoCfg))) + return sp +} + +// setupMySQL starts a throwaway MySQL container and returns a migrated SessionProxy. +func setupMySQL(ctx context.Context, t *testing.T) *sqldb.SessionProxy { + t.Helper() + my, err := testmysql.Run(ctx, + "mysql:8.4.5", + testmysql.WithDatabase(testDBName), + testmysql.WithUsername(testDBUser), + testmysql.WithPassword(testDBPassword), + ) + require.NoError(t, err) + t.Cleanup(func() { + if termErr := testcontainers.TerminateContainer(my); termErr != nil { + t.Logf("failed to terminate mysql container: %s", termErr) + } + }) + + host, err := my.Host(ctx) + require.NoError(t, err) + portStr, err := my.MappedPort(ctx, "3306/tcp") + require.NoError(t, err) + port, err := strconv.Atoi(portStr.Port()) + require.NoError(t, err) + + dbCfg := config.DBConfig{ + MySQL: &config.MySQLConfig{ + DatabaseConfig: config.DatabaseConfig{ + Host: host, + Port: port, + Database: testDBName, + }, + }, + } + sp, err := sqldb.NewSessionProxy(ctx, sqldb.SessionProxyConfig{ + DBConfig: dbCfg, + Username: testDBUser, + Password: testDBPassword, + MaxRetries: 5, + BaseDelay: 200 * time.Millisecond, + MaxDelay: 10 * time.Second, + }) + require.NoError(t, err) + t.Cleanup(func() { _ = sp.Close() }) + + memoCfg := &config.MemoizationConfig{ + DBConfig: dbCfg, + } + require.NoError(t, memodb.Migrate(ctx, sp, memodb.ConfigFromConfig(memoCfg))) + return sp +} + +func newQueries(t *testing.T, sp *sqldb.SessionProxy) memodb.MemoizationDB { + t.Helper() + q, err := memodb.NewQueries(testTableName, sp) + require.NoError(t, err) + return q +} + +func sampleOutputs(message string) *wfv1.Outputs { + return &wfv1.Outputs{ + Parameters: []wfv1.Parameter{ + {Name: "result", Value: wfv1.AnyStringPtr(message)}, + }, + } +} + +func TestQueriesSaveAndLoad(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupPostgres(ctx, t) + q := newQueries(t, sp) + + // Load returns nil when no entry exists. + rec, err := q.Load(ctx, testNamespace, testCacheName, "key1") + require.NoError(t, err) + assert.Nil(t, rec, "expected nil for missing key") + + // Save an entry and load it back. + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "key1", "node-abc", sampleOutputs("hello"), 2592000)) + rec, err = q.Load(ctx, testNamespace, testCacheName, "key1") + require.NoError(t, err) + require.NotNil(t, rec) + assert.Equal(t, "node-abc", rec.NodeID) + assert.Contains(t, rec.Outputs, "hello") +} + +func TestQueriesNamespaceIsolation(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupPostgres(ctx, t) + q := newQueries(t, sp) + + // Save the same cache_name+cache key in two different namespaces. + require.NoError(t, q.Save(ctx, "ns-a", testCacheName, "shared-key", "node-a", sampleOutputs("from-a"), 2592000)) + require.NoError(t, q.Save(ctx, "ns-b", testCacheName, "shared-key", "node-b", sampleOutputs("from-b"), 2592000)) + + // Each namespace should see its own entry. + recA, err := q.Load(ctx, "ns-a", testCacheName, "shared-key") + require.NoError(t, err) + require.NotNil(t, recA) + assert.Equal(t, "node-a", recA.NodeID) + assert.Contains(t, recA.Outputs, "from-a") + + recB, err := q.Load(ctx, "ns-b", testCacheName, "shared-key") + require.NoError(t, err) + require.NotNil(t, recB) + assert.Equal(t, "node-b", recB.NodeID) + assert.Contains(t, recB.Outputs, "from-b") +} + +func TestQueriesSaveReplaces(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupPostgres(ctx, t) + q := newQueries(t, sp) + + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "key3", "node-old", sampleOutputs("old"), 2592000)) + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "key3", "node-new", sampleOutputs("new"), 2592000)) + + rec, err := q.Load(ctx, testNamespace, testCacheName, "key3") + require.NoError(t, err) + require.NotNil(t, rec) + assert.Equal(t, "node-new", rec.NodeID) + assert.Contains(t, rec.Outputs, "new") +} + +func TestQueriesLoadSkipsExpiredEntries(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupPostgres(ctx, t) + q := newQueries(t, sp) + + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "expired-key", "node-old", sampleOutputs("old"), 2592000)) + + _, err := sp.Session().SQL(). + ExecContext(ctx, `UPDATE `+testTableName+` SET expires_at = $1 WHERE cache_key = $2`, time.Now().Add(-10*time.Second), "expired-key") + require.NoError(t, err) + + rec, err := q.Load(ctx, testNamespace, testCacheName, "expired-key") + require.NoError(t, err) + assert.Nil(t, rec, "expired entries should load as a cache miss") +} + +func TestQueriesPruneRemovesOldEntries(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupPostgres(ctx, t) + q := newQueries(t, sp) + + // Save an entry with a very short max_age (1 second) and one with 30 days. + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "old-key", "node-old", sampleOutputs("old"), 1)) + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "new-key", "node-new", sampleOutputs("new"), 2592000)) + + // Backdate old-key's expires_at so it is in the past. + _, err := sp.Session().SQL(). + ExecContext(ctx, `UPDATE `+testTableName+` SET expires_at = $1 WHERE cache_key = $2`, time.Now().Add(-10*time.Second), "old-key") + require.NoError(t, err) + + // Prune — old-key should be deleted (expires_at < now), new-key should survive. + n, err := q.Prune(ctx) + require.NoError(t, err) + assert.EqualValues(t, 1, n, "expected exactly one row pruned") + + old, err := q.Load(ctx, testNamespace, testCacheName, "old-key") + require.NoError(t, err) + assert.Nil(t, old, "old entry should have been pruned") + + fresh, err := q.Load(ctx, testNamespace, testCacheName, "new-key") + require.NoError(t, err) + assert.NotNil(t, fresh, "new entry should still exist") +} + +func TestQueriesPruneKeepsRecentEntries(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupPostgres(ctx, t) + q := newQueries(t, sp) + + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "recent", "node-1", sampleOutputs("v1"), 2592000)) + + // All entries are recent — nothing should be pruned. + n, err := q.Prune(ctx) + require.NoError(t, err) + assert.EqualValues(t, 0, n, "expected no rows pruned when all entries are fresh") +} + +// MySQL test variants — verify longtext and upsert behavior. + +func TestMySQLSaveAndLoad(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupMySQL(ctx, t) + q := newQueries(t, sp) + + rec, err := q.Load(ctx, testNamespace, testCacheName, "key1") + require.NoError(t, err) + assert.Nil(t, rec, "expected nil for missing key") + + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "key1", "node-abc", sampleOutputs("hello"), 2592000)) + rec, err = q.Load(ctx, testNamespace, testCacheName, "key1") + require.NoError(t, err) + require.NotNil(t, rec) + assert.Equal(t, "node-abc", rec.NodeID) + assert.Contains(t, rec.Outputs, "hello") +} + +func TestMySQLSaveReplaces(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupMySQL(ctx, t) + q := newQueries(t, sp) + + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "key3", "node-old", sampleOutputs("old"), 2592000)) + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "key3", "node-new", sampleOutputs("new"), 2592000)) + + rec, err := q.Load(ctx, testNamespace, testCacheName, "key3") + require.NoError(t, err) + require.NotNil(t, rec) + assert.Equal(t, "node-new", rec.NodeID) + assert.Contains(t, rec.Outputs, "new") +} + +func TestMySQLLoadSkipsExpiredEntries(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupMySQL(ctx, t) + q := newQueries(t, sp) + + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "expired-key", "node-old", sampleOutputs("old"), 2592000)) + + _, err := sp.Session().SQL(). + ExecContext(ctx, "UPDATE "+testTableName+" SET expires_at = ? WHERE cache_key = ?", time.Now().Add(-10*time.Second), "expired-key") + require.NoError(t, err) + + rec, err := q.Load(ctx, testNamespace, testCacheName, "expired-key") + require.NoError(t, err) + assert.Nil(t, rec, "expired entries should load as a cache miss") +} + +func TestMySQLPruneRemovesOldEntries(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupMySQL(ctx, t) + q := newQueries(t, sp) + + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "old-key", "node-old", sampleOutputs("old"), 1)) + require.NoError(t, q.Save(ctx, testNamespace, testCacheName, "new-key", "node-new", sampleOutputs("new"), 2592000)) + + // Backdate old-key's expires_at so it is in the past. + _, err := sp.Session().SQL(). + ExecContext(ctx, "UPDATE "+testTableName+" SET expires_at = ? WHERE cache_key = ?", time.Now().Add(-10*time.Second), "old-key") + require.NoError(t, err) + + n, err := q.Prune(ctx) + require.NoError(t, err) + assert.EqualValues(t, 1, n, "expected exactly one row pruned") + + old, err := q.Load(ctx, testNamespace, testCacheName, "old-key") + require.NoError(t, err) + assert.Nil(t, old, "old entry should have been pruned") + + fresh, err := q.Load(ctx, testNamespace, testCacheName, "new-key") + require.NoError(t, err) + assert.NotNil(t, fresh, "new entry should still exist") +} diff --git a/workflow/controller/cache/cache.go b/workflow/controller/cache/cache.go index 4ecfb0e145c0..4a066a5036e5 100644 --- a/workflow/controller/cache/cache.go +++ b/workflow/controller/cache/cache.go @@ -2,7 +2,10 @@ package cache import ( "context" + "fmt" + "os" "regexp" + "strconv" "sync" "time" @@ -10,13 +13,60 @@ import ( "k8s.io/client-go/kubernetes" wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" ) var cacheKeyRegex = regexp.MustCompile("^[a-zA-Z0-9][-a-zA-Z0-9]*$") +// defaultMaxAgeSeconds is 30 days in seconds, used when maxAge is not specified on the template. +const defaultMaxAgeSeconds int64 = 30 * 24 * 60 * 60 + +// resolvedDefaultMaxAge caches the DEFAULT_MAX_AGE env var so it is only read once. +var resolvedDefaultMaxAge struct { + once sync.Once + secs int64 + err error +} + +// ResolveMaxAgeSeconds converts a template's maxAge duration string to seconds for SQL-backed +// memoization cache entries. If maxAge is empty, it falls back to the DEFAULT_MAX_AGE env var +// (a Go duration string like "720h"), then to 30 days. Returns an error only if the duration +// string is malformed. +func ResolveMaxAgeSeconds(maxAge string) (int64, error) { + if maxAge == "" { + resolvedDefaultMaxAge.once.Do(func() { + envVal := os.Getenv("DEFAULT_MAX_AGE") + if envVal == "" { + resolvedDefaultMaxAge.secs = defaultMaxAgeSeconds + return + } + // Try parsing as a Go duration first (e.g. "720h") + if d, err := time.ParseDuration(envVal); err == nil { + resolvedDefaultMaxAge.secs = int64(d.Seconds()) + return + } + // Fall back to parsing as raw seconds (e.g. "2592000") + if secs, err := strconv.ParseInt(envVal, 10, 64); err == nil { + resolvedDefaultMaxAge.secs = secs + return + } + resolvedDefaultMaxAge.err = fmt.Errorf("invalid DEFAULT_MAX_AGE value %q: must be a Go duration (e.g. 720h) or integer seconds", envVal) + }) + return resolvedDefaultMaxAge.secs, resolvedDefaultMaxAge.err + } + d, err := time.ParseDuration(maxAge) + if err != nil { + return 0, fmt.Errorf("invalid maxAge %q: %w", maxAge, err) + } + return int64(d.Seconds()), nil +} + type MemoizationCache interface { Load(ctx context.Context, key string) (*Entry, error) - Save(ctx context.Context, key string, nodeID string, value *wfv1.Outputs) error + // Save stores the outputs of a completed memoized node. ConfigMap-backed caches ignore maxAge. + // SQL-backed caches use maxAge, or DEFAULT_MAX_AGE when maxAge is empty, to compute expires_at. + Save(ctx context.Context, key string, nodeID string, value *wfv1.Outputs, maxAge string) error } type Entry struct { @@ -51,35 +101,56 @@ func (e *Entry) GetOutputsWithMaxAge(maxAge time.Duration) (*wfv1.Outputs, bool) type cacheFactory struct { caches map[string]MemoizationCache kubeclient kubernetes.Interface - namespace string lock sync.RWMutex + queries memodb.MemoizationDB } type Factory interface { - GetCache(ct Type, name string) MemoizationCache + GetCache(ctx context.Context, ct Type, namespace, name string) MemoizationCache + // SetQueries configures the factory to use database-backed caching with the given + // MemoizationDB. Calling this clears any previously created cache instances + // so they are recreated against the SQL backend. + SetQueries(q memodb.MemoizationDB) } -func NewCacheFactory(ki kubernetes.Interface, ns string) Factory { +func NewCacheFactory(ki kubernetes.Interface) Factory { return &cacheFactory{ - make(map[string]MemoizationCache), - ki, - ns, - sync.RWMutex{}, + caches: make(map[string]MemoizationCache), + kubeclient: ki, } } type Type string const ( - // Only config maps are currently supported for caching + // ConfigMapCache is a cache type identifier used as a key prefix in the cache map. + // When a MemoizationDB is configured, SQL-backed memoization semantics are used instead. ConfigMapCache Type = "ConfigMapCache" ) -// Returns a cache if it exists and creates it otherwise -func (cf *cacheFactory) GetCache(ct Type, name string) MemoizationCache { +// SetQueries configures the factory's memoization backend, clearing any previously +// cached instances so they are recreated against the new backend. A nil MemoizationDB +// selects ConfigMap-backed caching; a non-nil MemoizationDB selects SQL-backed +// memoization semantics, even if the DB implementation is disabled/no-op. +func (cf *cacheFactory) SetQueries(q memodb.MemoizationDB) { + cf.lock.Lock() + defer cf.lock.Unlock() + cf.queries = q + cf.caches = make(map[string]MemoizationCache) +} + +// GetCache returns a cache scoped to the given workflow namespace if it exists and creates it +// otherwise. +func (cf *cacheFactory) GetCache(ctx context.Context, ct Type, namespace, name string) MemoizationCache { + logger := logging.RequireLoggerFromContext(ctx) + if namespace == "" { + logger.WithField("cacheName", name).Error(ctx, "Workflow namespace is required to resolve memoization cache") + return nil + } + cf.lock.RLock() - idx := string(ct) + "." + name + idx := string(ct) + "." + namespace + "." + name if c := cf.caches[idx]; c != nil { cf.lock.RUnlock() return c @@ -95,7 +166,12 @@ func (cf *cacheFactory) GetCache(ct Type, name string) MemoizationCache { switch ct { case ConfigMapCache: - c := NewConfigMapCache(cf.namespace, cf.kubeclient, name) + var c MemoizationCache + if cf.queries != nil { + c = newSQLDBCache(namespace, name, func() memodb.MemoizationDB { return cf.queries }, &cf.lock) + } else { + c = NewConfigMapCache(namespace, cf.kubeclient, name) + } cf.caches[idx] = c return c default: diff --git a/workflow/controller/cache/cache_factory_test.go b/workflow/controller/cache/cache_factory_test.go new file mode 100644 index 000000000000..d8c2255757a9 --- /dev/null +++ b/workflow/controller/cache/cache_factory_test.go @@ -0,0 +1,138 @@ +package cache + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/client-go/kubernetes/fake" + + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" +) + +func TestCacheFactoryNamespacesCachesSeparately(t *testing.T) { + ctx := logging.TestContext(t.Context()) + factory := NewCacheFactory(fake.NewSimpleClientset()) + + cacheA := factory.GetCache(ctx, ConfigMapCache, "ns-a", "shared-cache") + cacheARepeat := factory.GetCache(ctx, ConfigMapCache, "ns-a", "shared-cache") + cacheB := factory.GetCache(ctx, ConfigMapCache, "ns-b", "shared-cache") + + require.NotNil(t, cacheA) + require.NotNil(t, cacheARepeat) + require.NotNil(t, cacheB) + assert.Same(t, cacheA, cacheARepeat) + assert.NotSame(t, cacheA, cacheB) + assert.Equal(t, "ns-a", cacheA.(*configMapCache).namespace) + assert.Equal(t, "ns-b", cacheB.(*configMapCache).namespace) +} + +func TestCacheFactoryRequiresNamespace(t *testing.T) { + ctx := logging.TestContext(t.Context()) + factory := NewCacheFactory(fake.NewSimpleClientset()) + + cache := factory.GetCache(ctx, ConfigMapCache, "", "shared-cache") + assert.Nil(t, cache) +} + +type testMemoizationDB struct { + enabled bool + saveCalls atomic.Int32 + loadStart chan struct{} + loadBlock chan struct{} +} + +func (t *testMemoizationDB) Load(context.Context, string, string, string) (*memodb.CacheRecord, error) { + if t.loadStart != nil { + close(t.loadStart) + } + if t.loadBlock != nil { + <-t.loadBlock + } + return nil, nil +} + +func (t *testMemoizationDB) Save(context.Context, string, string, string, string, *wfv1.Outputs, int64) error { + t.saveCalls.Add(1) + return nil +} + +func (*testMemoizationDB) Prune(context.Context) (int64, error) { + return 0, nil +} + +func (t *testMemoizationDB) IsEnabled() bool { + return t.enabled +} + +func TestCacheFactoryStaleSQLCacheNoopsAfterDisable(t *testing.T) { + ctx := logging.TestContext(t.Context()) + factory := NewCacheFactory(fake.NewSimpleClientset()).(*cacheFactory) + queries := &testMemoizationDB{enabled: true} + factory.SetQueries(queries) + + cache := factory.GetCache(ctx, ConfigMapCache, "default", "shared-cache") + require.NotNil(t, cache) + + factory.SetQueries(nil) + + require.NoError(t, cache.Save(ctx, "memo-key", "node-1", &wfv1.Outputs{}, "1h")) + assert.Zero(t, queries.saveCalls.Load()) +} + +func TestCacheFactoryDisableWaitsForInflightSQLLoad(t *testing.T) { + ctx := logging.TestContext(t.Context()) + factory := NewCacheFactory(fake.NewSimpleClientset()).(*cacheFactory) + queries := &testMemoizationDB{ + enabled: true, + loadStart: make(chan struct{}), + loadBlock: make(chan struct{}), + } + factory.SetQueries(queries) + + cache := factory.GetCache(ctx, ConfigMapCache, "default", "shared-cache") + require.NotNil(t, cache) + + loadDone := make(chan struct{}) + go func() { + defer close(loadDone) + _, _ = cache.Load(ctx, "memo-key") + }() + + select { + case <-queries.loadStart: + case <-time.After(time.Second): + t.Fatal("expected SQL load to start") + } + + setQueriesDone := make(chan struct{}) + go func() { + factory.SetQueries(nil) + close(setQueriesDone) + }() + + select { + case <-setQueriesDone: + t.Fatal("expected SetQueries to wait for in-flight SQL load") + case <-time.After(50 * time.Millisecond): + } + + close(queries.loadBlock) + + select { + case <-setQueriesDone: + case <-time.After(time.Second): + t.Fatal("expected SetQueries to finish after load completes") + } + + select { + case <-loadDone: + case <-time.After(time.Second): + t.Fatal("expected SQL load goroutine to finish") + } +} diff --git a/workflow/controller/cache/configmap_cache.go b/workflow/controller/cache/configmap_cache.go index d9dc06a96203..2004b75ba36c 100644 --- a/workflow/controller/cache/configmap_cache.go +++ b/workflow/controller/cache/configmap_cache.go @@ -132,7 +132,10 @@ func (c *configMapCache) load(ctx context.Context, key string) (*Entry, error) { return &entry, nil } -func (c *configMapCache) Save(ctx context.Context, key string, nodeID string, value *wfv1.Outputs) error { +// Save stores a memoization entry in a ConfigMap. The final argument is maxAge from the +// MemoizationCache interface; ConfigMap-backed memoization does not support TTL/expiration, +// so this implementation intentionally ignores it. +func (c *configMapCache) Save(ctx context.Context, key string, nodeID string, value *wfv1.Outputs, _ string) error { err := retry.OnError(kwait.Backoff{ Duration: time.Second, Factor: 2, diff --git a/workflow/controller/cache/sqldb_cache.go b/workflow/controller/cache/sqldb_cache.go new file mode 100644 index 000000000000..b4c3824bf2c9 --- /dev/null +++ b/workflow/controller/cache/sqldb_cache.go @@ -0,0 +1,94 @@ +package cache + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" +) + +type sqlDBCache struct { + namespace string + name string + getQueries func() memodb.MemoizationDB + lock *sync.RWMutex +} + +func newSQLDBCache(namespace, name string, getQueries func() memodb.MemoizationDB, lock *sync.RWMutex) MemoizationCache { + return &sqlDBCache{ + namespace: namespace, + name: name, + getQueries: getQueries, + lock: lock, + } +} + +// withQueries is necessary to make runtime enable/disable of SQL memoization +// safe, deterministic, and backward-safe for already-created cache instances. +func (c *sqlDBCache) withQueries(fn func(memodb.MemoizationDB) error) error { + if c.lock == nil { + return fn(memodb.NullMemoizationDB) + } + c.lock.RLock() + defer c.lock.RUnlock() + + queries := memodb.NullMemoizationDB + if c.getQueries != nil { + if q := c.getQueries(); q != nil { + queries = q + } + } + return fn(queries) +} + +func (c *sqlDBCache) Load(ctx context.Context, key string) (*Entry, error) { + if !cacheKeyRegex.MatchString(key) { + return nil, fmt.Errorf("invalid cache key: %s", key) + } + var record *memodb.CacheRecord + err := c.withQueries(func(queries memodb.MemoizationDB) error { + if !queries.IsEnabled() { + return nil + } + var err error + record, err = queries.Load(ctx, c.namespace, c.name, key) + return err + }) + if err != nil { + return nil, fmt.Errorf("memoization db load failed: %w", err) + } + if record == nil { + return nil, nil + } + var outputs wfv1.Outputs + if err := json.Unmarshal([]byte(record.Outputs), &outputs); err != nil { + return nil, fmt.Errorf("malformed memoization db entry: could not unmarshal outputs JSON: %w", err) + } + return &Entry{ + NodeID: record.NodeID, + Outputs: &outputs, + CreationTimestamp: metav1.Time{Time: record.CreatedAt}, + LastHitTimestamp: metav1.Time{Time: record.CreatedAt}, + }, nil +} + +func (c *sqlDBCache) Save(ctx context.Context, key string, nodeID string, value *wfv1.Outputs, maxAge string) error { + if !cacheKeyRegex.MatchString(key) { + return fmt.Errorf("invalid cache key: %s", key) + } + return c.withQueries(func(queries memodb.MemoizationDB) error { + if !queries.IsEnabled() { + return nil + } + maxAgeSeconds, err := ResolveMaxAgeSeconds(maxAge) + if err != nil { + return err + } + return queries.Save(ctx, c.namespace, c.name, key, nodeID, value, maxAgeSeconds) + }) +} diff --git a/workflow/controller/cache/sqldb_cache_test.go b/workflow/controller/cache/sqldb_cache_test.go new file mode 100644 index 000000000000..d358d86fcfb9 --- /dev/null +++ b/workflow/controller/cache/sqldb_cache_test.go @@ -0,0 +1,234 @@ +//go:build !windows + +package cache + +import ( + "context" + "encoding/json" + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + testcontainers "github.com/testcontainers/testcontainers-go" + testpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/argoproj/argo-workflows/v4/config" + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" + "github.com/argoproj/argo-workflows/v4/util/sqldb" +) + +const ( + testDBName = "cachetest" + testDBUser = "user" + testDBPassword = "pass" + testNamespace = "default" + testCacheName = "my-cache" +) + +var testTableName = memodb.TableName(nil) + +func setupTestPostgres(ctx context.Context, t *testing.T) *sqldb.SessionProxy { + t.Helper() + pg, err := testpostgres.Run(ctx, + "postgres:17.4-alpine", + testpostgres.WithDatabase(testDBName), + testpostgres.WithUsername(testDBUser), + testpostgres.WithPassword(testDBPassword), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(30*time.Second), + ), + ) + require.NoError(t, err) + t.Cleanup(func() { + if termErr := testcontainers.TerminateContainer(pg); termErr != nil { + t.Logf("failed to terminate postgres container: %s", termErr) + } + }) + + host, err := pg.Host(ctx) + require.NoError(t, err) + portStr, err := pg.MappedPort(ctx, "5432/tcp") + require.NoError(t, err) + port, err := strconv.Atoi(portStr.Port()) + require.NoError(t, err) + + dbCfg := config.DBConfig{ + PostgreSQL: &config.PostgreSQLConfig{ + DatabaseConfig: config.DatabaseConfig{ + Host: host, + Port: port, + Database: testDBName, + }, + }, + } + sp, err := sqldb.NewSessionProxy(ctx, sqldb.SessionProxyConfig{ + DBConfig: dbCfg, + Username: testDBUser, + Password: testDBPassword, + MaxRetries: 5, + BaseDelay: 200 * time.Millisecond, + MaxDelay: 10 * time.Second, + }) + require.NoError(t, err) + t.Cleanup(func() { _ = sp.Close() }) + + memoCfg := &config.MemoizationConfig{ + DBConfig: dbCfg, + } + require.NoError(t, memodb.Migrate(ctx, sp, memodb.ConfigFromConfig(memoCfg))) + return sp +} + +func newTestSQLDBCache(t *testing.T, sp *sqldb.SessionProxy) MemoizationCache { + t.Helper() + queries, err := memodb.NewQueries(testTableName, sp) + require.NoError(t, err) + var lock sync.RWMutex + return newSQLDBCache(testNamespace, testCacheName, func() memodb.MemoizationDB { return queries }, &lock) +} + +func TestSQLDBCacheSaveAndLoad(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupTestPostgres(ctx, t) + + c := newTestSQLDBCache(t, sp) + + // Load returns nil for missing key. + entry, err := c.Load(ctx, "key1") + require.NoError(t, err) + assert.Nil(t, entry) + + // Save and load back. + outputs := &wfv1.Outputs{ + Parameters: []wfv1.Parameter{ + {Name: "result", Value: wfv1.AnyStringPtr("hello")}, + }, + } + require.NoError(t, c.Save(ctx, "key1", "node-abc", outputs, "720h")) + + entry, err = c.Load(ctx, "key1") + require.NoError(t, err) + require.NotNil(t, entry) + assert.Equal(t, "node-abc", entry.NodeID) + assert.True(t, entry.Hit()) + require.NotNil(t, entry.Outputs) + require.Len(t, entry.Outputs.Parameters, 1) + assert.Equal(t, "result", entry.Outputs.Parameters[0].Name) + assert.Equal(t, "hello", entry.Outputs.Parameters[0].Value.String()) + assert.False(t, entry.CreationTimestamp.IsZero()) +} + +func TestSQLDBCacheInvalidKey(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupTestPostgres(ctx, t) + + c := newTestSQLDBCache(t, sp) + + // Keys with invalid characters should be rejected. + _, err := c.Load(ctx, "invalid key!") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid cache key") + + err = c.Save(ctx, "invalid key!", "node-1", &wfv1.Outputs{}, "1h") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid cache key") +} + +func TestSQLDBCacheOutputsRoundTrip(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupTestPostgres(ctx, t) + + c := newTestSQLDBCache(t, sp) + + // Save complex outputs and verify they round-trip through JSON. + outputs := &wfv1.Outputs{ + Parameters: []wfv1.Parameter{ + {Name: "p1", Value: wfv1.AnyStringPtr("v1")}, + {Name: "p2", Value: wfv1.AnyStringPtr("v2")}, + }, + } + require.NoError(t, c.Save(ctx, "complex-key", "node-x", outputs, "1h")) + + entry, err := c.Load(ctx, "complex-key") + require.NoError(t, err) + require.NotNil(t, entry) + + // Verify the round-tripped outputs match by comparing JSON. + originalJSON, _ := json.Marshal(outputs) + loadedJSON, _ := json.Marshal(entry.Outputs) + assert.JSONEq(t, string(originalJSON), string(loadedJSON)) +} + +func TestSQLDBCacheGetOutputsWithMaxAge(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupTestPostgres(ctx, t) + + c := newTestSQLDBCache(t, sp) + + outputs := &wfv1.Outputs{ + Parameters: []wfv1.Parameter{ + {Name: "result", Value: wfv1.AnyStringPtr("cached")}, + }, + } + require.NoError(t, c.Save(ctx, "ttl-key", "node-ttl", outputs, "1h")) + + entry, err := c.Load(ctx, "ttl-key") + require.NoError(t, err) + require.NotNil(t, entry) + + // With a large maxAge, outputs should be returned. + out, ok := entry.GetOutputsWithMaxAge(1 * time.Hour) + assert.True(t, ok) + assert.NotNil(t, out) + + // With a zero maxAge, outputs should be expired. + out, ok = entry.GetOutputsWithMaxAge(0) + assert.False(t, ok) + assert.Nil(t, out) +} + +func TestSQLDBCacheUpsertRefreshesCreatedAt(t *testing.T) { + ctx := logging.TestContext(t.Context()) + sp := setupTestPostgres(ctx, t) + + c := newTestSQLDBCache(t, sp) + + outputs := &wfv1.Outputs{ + Parameters: []wfv1.Parameter{ + {Name: "result", Value: wfv1.AnyStringPtr("v1")}, + }, + } + require.NoError(t, c.Save(ctx, "upsert-key", "node-1", outputs, "1h")) + + entry1, err := c.Load(ctx, "upsert-key") + require.NoError(t, err) + require.NotNil(t, entry1) + createdAt1 := entry1.CreationTimestamp.Time + + // Small delay to ensure timestamps differ. + time.Sleep(10 * time.Millisecond) + + // Re-save with updated outputs. + outputs2 := &wfv1.Outputs{ + Parameters: []wfv1.Parameter{ + {Name: "result", Value: wfv1.AnyStringPtr("v2")}, + }, + } + require.NoError(t, c.Save(ctx, "upsert-key", "node-2", outputs2, "1h")) + + entry2, err := c.Load(ctx, "upsert-key") + require.NoError(t, err) + require.NotNil(t, entry2) + + assert.True(t, entry2.CreationTimestamp.After(createdAt1), + "created_at should be refreshed on upsert: first=%v, second=%v", createdAt1, entry2.CreationTimestamp.Time) + assert.Equal(t, "node-2", entry2.NodeID) +} diff --git a/workflow/controller/cache_test.go b/workflow/controller/cache_test.go index 14c69fa41c64..e01885fdb5bd 100644 --- a/workflow/controller/cache_test.go +++ b/workflow/controller/cache_test.go @@ -94,7 +94,7 @@ func TestConfigMapCacheSave(t *testing.T) { outputs := wfv1.Outputs{} outputs.Parameters = append(outputs.Parameters, MockParam) - err := c.Save(ctx, "hi-there-world", "", &outputs) + err := c.Save(ctx, "hi-there-world", "", &outputs, "") require.NoError(t, err) cm, err := controller.kubeclientset.CoreV1().ConfigMaps("default").Get(ctx, "whalesay-cache", metav1.GetOptions{}) @@ -104,3 +104,15 @@ func TestConfigMapCacheSave(t *testing.T) { wfv1.MustUnmarshal([]byte(cm.Data["hi-there-world"]), &entry) assert.Equal(t, entry.LastHitTimestamp.Time, entry.CreationTimestamp.Time) } + +func TestConfigMapCacheSaveIgnoresInvalidDefaultMaxAge(t *testing.T) { + t.Setenv("DEFAULT_MAX_AGE", "definitely-not-a-duration") + + ctx := logging.TestContext(t.Context()) + cancel, controller := newController(ctx) + defer cancel() + + c := cache.NewConfigMapCache("default", controller.kubeclientset, "whalesay-cache") + outputs := &wfv1.Outputs{} + require.NoError(t, c.Save(ctx, "hi-there-world", "", outputs, "")) +} diff --git a/workflow/controller/config.go b/workflow/controller/config.go index 9a199361c2b0..d58c26c8104c 100644 --- a/workflow/controller/config.go +++ b/workflow/controller/config.go @@ -13,11 +13,37 @@ import ( persist "github.com/argoproj/argo-workflows/v4/persist/sqldb" "github.com/argoproj/argo-workflows/v4/util/instanceid" "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" "github.com/argoproj/argo-workflows/v4/util/sqldb" "github.com/argoproj/argo-workflows/v4/workflow/artifactrepositories" + controllercache "github.com/argoproj/argo-workflows/v4/workflow/controller/cache" "github.com/argoproj/argo-workflows/v4/workflow/hydrator" ) +var ( + memoSessionProxyFromConfig = memodb.SessionProxyFromConfig + memoizationMigrate = memodb.Migrate +) + +// resetMemoizationBackend switches memoization query backend state and closes the previous +// database session. It updates both controller-level query access and cache-factory query access, +// then clears wfc.memoSessionProxy. If sessionProxy is nil, it closes the currently tracked +// session; otherwise it closes the explicitly provided one. +func (wfc *WorkflowController) resetMemoizationBackend(ctx context.Context, sessionProxy *sqldb.SessionProxy, cacheQueries memodb.MemoizationDB) { + logger := logging.RequireLoggerFromContext(ctx) + wfc.setMemoizationQueries(cacheQueries) + wfc.cacheFactory.SetQueries(cacheQueries) + if sessionProxy == nil { + sessionProxy = wfc.memoSessionProxy + } + wfc.memoSessionProxy = nil + if sessionProxy != nil { + if err := sessionProxy.Close(); err != nil { + logger.WithError(err).Warn(ctx, "Failed to close memoization database session") + } + } +} + func (wfc *WorkflowController) updateConfig(ctx context.Context) error { logger := logging.RequireLoggerFromContext(ctx) _, err := yaml.Marshal(wfc.Config) @@ -78,6 +104,42 @@ func (wfc *WorkflowController) updateConfig(ctx context.Context) error { logger.Info(ctx, "Persistence configuration disabled") } + memoCfg := wfc.Config.Memoization + if memoCfg != nil { + logger.Info(ctx, "Memoization database configuration enabled") + sessionProxy := wfc.memoSessionProxy + if sessionProxy == nil { + sessionProxy = memoSessionProxyFromConfig(ctx, wfc.kubeclientset, wfc.namespace, memoCfg) + if sessionProxy == nil { + logger.Warn(ctx, "Memoization database unavailable; memoization disabled") + wfc.resetMemoizationBackend(ctx, nil, memodb.NullMemoizationDB) + goto memoizationConfigured + } + } + cfg := memodb.ConfigFromConfig(memoCfg) + if err := memoizationMigrate(ctx, sessionProxy, cfg); err != nil { + logger.WithError(err).Error(ctx, "Memoization database migration failed; memoization disabled") + wfc.resetMemoizationBackend(ctx, sessionProxy, memodb.NullMemoizationDB) + goto memoizationConfigured + } + queries, err := memodb.NewQueries(cfg.TableName, sessionProxy) + if err != nil { + logger.WithError(err).Error(ctx, "Memoization database initialization failed; memoization disabled") + wfc.resetMemoizationBackend(ctx, sessionProxy, memodb.NullMemoizationDB) + goto memoizationConfigured + } + wfc.memoSessionProxy = sessionProxy + wfc.setMemoizationQueries(queries) + wfc.cacheFactory.SetQueries(queries) + } else { + if wfc.memoSessionProxy != nil { + logger.Info(ctx, "Memoization database configuration removed") + } + wfc.resetMemoizationBackend(ctx, nil, nil) + logger.Info(ctx, "Memoization database configuration disabled; using ConfigMap-based caching") + } + +memoizationConfigured: wfc.hydrator = hydrator.New(wfc.offloadNodeStatusRepo) wfc.updateEstimatorFactory(ctx) wfc.rateLimiter = wfc.newRateLimiter() @@ -90,6 +152,11 @@ func (wfc *WorkflowController) updateConfig(ctx context.Context) error { return nil } +// getMemoizationCache returns the memoization cache for the given namespace and name. +func (wfc *WorkflowController) getMemoizationCache(ctx context.Context, namespace, name string) controllercache.MemoizationCache { + return wfc.cacheFactory.GetCache(ctx, controllercache.ConfigMapCache, namespace, name) +} + // initDB inits argo DB tables func (wfc *WorkflowController) initDB(ctx context.Context) error { persistence := wfc.Config.Persistence diff --git a/workflow/controller/config_test.go b/workflow/controller/config_test.go index 6144de86b456..b770e6c1e309 100644 --- a/workflow/controller/config_test.go +++ b/workflow/controller/config_test.go @@ -1,12 +1,23 @@ package controller import ( + "context" + stderrors "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/upper/db/v4" + apierr "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "github.com/argoproj/argo-workflows/v4/config" + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" + "github.com/argoproj/argo-workflows/v4/util/sqldb" + controllercache "github.com/argoproj/argo-workflows/v4/workflow/controller/cache" ) func TestUpdateConfig(t *testing.T) { @@ -20,4 +31,101 @@ func TestUpdateConfig(t *testing.T) { assert.NotNil(t, controller.archiveLabelSelector) assert.NotNil(t, controller.wfArchive) assert.NotNil(t, controller.offloadNodeStatusRepo) + assert.Equal(t, memodb.NullMemoizationDB, controller.memoQueries) +} + +func TestUpdateConfigMemoizationSessionFailureDisablesMemoization(t *testing.T) { + ctx := logging.TestContext(t.Context()) + cancel, controller := newController(ctx) + defer cancel() + + controller.Config.Memoization = &config.MemoizationConfig{} + + origMemoSessionProxyFromConfig := memoSessionProxyFromConfig + memoSessionProxyFromConfig = func(context.Context, kubernetes.Interface, string, *config.MemoizationConfig) *sqldb.SessionProxy { + return nil + } + t.Cleanup(func() { + memoSessionProxyFromConfig = origMemoSessionProxyFromConfig + }) + + err := controller.updateConfig(ctx) + require.NoError(t, err) + assert.Nil(t, controller.memoSessionProxy) + assert.Equal(t, memodb.NullMemoizationDB, controller.getMemoizationQueries()) + + cache := controller.getMemoizationCache(ctx, "default", "memo-disabled-cache") + require.NotNil(t, cache) + require.NoError(t, cache.Save(ctx, "memo-key", "", &wfv1.Outputs{}, "")) + + _, err = controller.kubeclientset.CoreV1().ConfigMaps("default").Get(ctx, "memo-disabled-cache", metav1.GetOptions{}) + assert.True(t, apierr.IsNotFound(err)) +} + +type observingCacheFactory struct { + setQueries func(memodb.MemoizationDB) +} + +func (o *observingCacheFactory) GetCache(context.Context, controllercache.Type, string, string) controllercache.MemoizationCache { + return nil +} + +func (o *observingCacheFactory) SetQueries(q memodb.MemoizationDB) { + if o.setQueries != nil { + o.setQueries(q) + } +} + +func TestUpdateConfigMemoizationDisableDetachesCachesBeforeClosingSession(t *testing.T) { + ctx := logging.TestContext(t.Context()) + cancel, controller := newController(ctx) + defer cancel() + + sessionProxy := sqldb.NewSessionProxyFromSession(nil, &config.DBConfig{}, "user", "password") + controller.memoSessionProxy = sessionProxy + + var setQueriesSawClosedProxy bool + controller.cacheFactory = &observingCacheFactory{ + setQueries: func(memodb.MemoizationDB) { + err := sessionProxy.With(ctx, func(db.Session) error { return nil }) + setQueriesSawClosedProxy = err != nil && err.Error() == "session proxy is closed" + }, + } + + err := controller.updateConfig(ctx) + require.NoError(t, err) + assert.False(t, setQueriesSawClosedProxy) + assert.Nil(t, controller.memoSessionProxy) + require.EqualError(t, sessionProxy.With(ctx, func(db.Session) error { return nil }), "session proxy is closed") +} + +func TestUpdateConfigMemoizationMigrationFailureDisablesMemoization(t *testing.T) { + ctx := logging.TestContext(t.Context()) + cancel, controller := newController(ctx) + defer cancel() + + controller.Config.Memoization = &config.MemoizationConfig{} + sessionProxy := sqldb.NewSessionProxyFromSession(nil, &config.DBConfig{}, "user", "password") + controller.memoSessionProxy = sessionProxy + + origMemoizationMigrate := memoizationMigrate + memoizationMigrate = func(context.Context, *sqldb.SessionProxy, memodb.Config) error { + return stderrors.New("boom") + } + t.Cleanup(func() { + memoizationMigrate = origMemoizationMigrate + }) + + err := controller.updateConfig(ctx) + require.NoError(t, err) + assert.Nil(t, controller.memoSessionProxy) + assert.Equal(t, memodb.NullMemoizationDB, controller.getMemoizationQueries()) + require.EqualError(t, sessionProxy.With(ctx, func(db.Session) error { return nil }), "session proxy is closed") + + cache := controller.getMemoizationCache(ctx, "default", "memo-migrate-disabled-cache") + require.NotNil(t, cache) + require.NoError(t, cache.Save(ctx, "memo-key", "", &wfv1.Outputs{}, "")) + + _, err = controller.kubeclientset.CoreV1().ConfigMaps("default").Get(ctx, "memo-migrate-disabled-cache", metav1.GetOptions{}) + assert.True(t, apierr.IsNotFound(err)) } diff --git a/workflow/controller/controller.go b/workflow/controller/controller.go index 02cec7e6be70..0a5bf8723be0 100644 --- a/workflow/controller/controller.go +++ b/workflow/controller/controller.go @@ -46,6 +46,7 @@ import ( "github.com/argoproj/argo-workflows/v4/util/deprecation" "github.com/argoproj/argo-workflows/v4/util/env" "github.com/argoproj/argo-workflows/v4/util/errors" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" rbacutil "github.com/argoproj/argo-workflows/v4/util/rbac" "github.com/argoproj/argo-workflows/v4/util/retry" utilsqldb "github.com/argoproj/argo-workflows/v4/util/sqldb" @@ -134,6 +135,9 @@ type WorkflowController struct { throttler sync.Throttler workflowKeyLock syncpkg.KeyLock // used to lock workflows for exclusive modification or access sessionProxy *utilsqldb.SessionProxy + memoSessionProxy *utilsqldb.SessionProxy + memoQueries memodb.MemoizationDB + memoizationLock gosync.RWMutex offloadNodeStatusRepo sqldb.OffloadNodeStatusRepo hydrator hydrator.Interface wfArchive sqldb.WorkflowArchive @@ -209,7 +213,7 @@ func NewWorkflowController(ctx context.Context, restConfig *rest.Config, kubecli cliExecutorLogFormat: executorLogFormat, configController: config.NewController(namespace, configMap, kubeclientset), workflowKeyLock: syncpkg.NewKeyLock(), - cacheFactory: controllercache.NewCacheFactory(kubeclientset, namespace), + cacheFactory: controllercache.NewCacheFactory(kubeclientset), eventRecorderManager: events.NewEventRecorderManager(kubeclientset), progressPatchTickDuration: env.LookupEnvDurationOr(ctx, common.EnvVarProgressPatchTickDuration, 1*time.Minute), progressFileTickDuration: env.LookupEnvDurationOr(ctx, common.EnvVarProgressFileTickDuration, 3*time.Second), @@ -259,6 +263,39 @@ func (wfc *WorkflowController) newThrottler() sync.Throttler { return sync.NewMultiThrottler(wfc.Config.Parallelism, wfc.Config.NamespaceParallelism, f) } +// getMemoizationQueries returns the currently configured memoization query backend. +// If no backend is configured, it returns NullMemoizationDB to provide no-op behavior. +func (wfc *WorkflowController) getMemoizationQueries() memodb.MemoizationDB { + wfc.memoizationLock.RLock() + defer wfc.memoizationLock.RUnlock() + if wfc.memoQueries == nil { + return memodb.NullMemoizationDB + } + return wfc.memoQueries +} + +// setMemoizationQueries updates the memoization query backend used by the controller. +// A nil backend is normalized to NullMemoizationDB to keep callers nil-safe. +func (wfc *WorkflowController) setMemoizationQueries(queries memodb.MemoizationDB) { + if queries == nil { + queries = memodb.NullMemoizationDB + } + wfc.memoizationLock.Lock() + defer wfc.memoizationLock.Unlock() + wfc.memoQueries = queries +} + +// withMemoizationQueries executes fn while holding a read lock on the memoization backend. +// If no backend is configured, fn is invoked with NullMemoizationDB. +func (wfc *WorkflowController) withMemoizationQueries(fn func(memodb.MemoizationDB) error) error { + wfc.memoizationLock.RLock() + defer wfc.memoizationLock.RUnlock() + if wfc.memoQueries == nil { + return fn(memodb.NullMemoizationDB) + } + return fn(wfc.memoQueries) +} + // runGCcontroller runs the workflow garbage collector controller func (wfc *WorkflowController) runGCcontroller(ctx context.Context, workflowTTLWorkers int) { defer runtimeutil.HandleCrashWithContext(ctx, runtimeutil.PanicHandlers...) @@ -404,6 +441,7 @@ func (wfc *WorkflowController) Run(ctx context.Context, wfWorkers, workflowTTLWo go wfc.workflowGarbageCollector(ctx) go wfc.archivedWorkflowGarbageCollector(ctx) + go wfc.memoizationCacheGarbageCollector(ctx) go wfc.runGCcontroller(ctx, workflowTTLWorkers) go wfc.runCronController(ctx, cronWorkflowWorkers) @@ -721,6 +759,53 @@ func (wfc *WorkflowController) archivedWorkflowGarbageCollector(ctx context.Cont } } +func (wfc *WorkflowController) memoizationCacheGarbageCollector(ctx context.Context) { + defer runtimeutil.HandleCrashWithContext(ctx, runtimeutil.PanicHandlers...) + + logger := logging.RequireLoggerFromContext(ctx) + logger = logger.WithField("component", "memo_cache_garbage_collector") + ctx = logging.WithLogger(ctx, logger) + + periodicity := env.LookupEnvDurationOr(ctx, "MEMO_CACHE_GC_PERIOD", 24*time.Hour) + if periodicity <= 0 { + logger.Info(ctx, "MEMO_CACHE_GC_PERIOD is zero or negative - cache GC disabled") + return + } + logger.Info(ctx, "Memoization cache GC goroutine started") + + ticker := time.NewTicker(periodicity) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + var ( + n int64 + active bool + ) + err := wfc.withMemoizationQueries(func(queries memodb.MemoizationDB) error { + if !queries.IsEnabled() { + return nil + } + active = true + logger.Info(ctx, "Performing memoization cache GC") + var err error + n, err = queries.Prune(ctx) + return err + }) + if !active { + continue + } + if err != nil { + logger.WithError(err).Error(ctx, "Failed to prune memoization cache") + } else { + logger.WithField("deleted", n).Info(ctx, "Memoization cache GC complete") + } + } + } +} + func (wfc *WorkflowController) runWorker(ctx context.Context) { defer runtimeutil.HandleCrashWithContext(ctx, runtimeutil.PanicHandlers...) diff --git a/workflow/controller/controller_test.go b/workflow/controller/controller_test.go index 7b24616e755f..93170f93901f 100644 --- a/workflow/controller/controller_test.go +++ b/workflow/controller/controller_test.go @@ -296,7 +296,7 @@ func newController(ctx context.Context, options ...any) (context.CancelFunc, *Wo estimatorFactory: estimation.DummyEstimatorFactory, eventRecorderManager: &testEventRecorderManager{eventRecorder: record.NewFakeRecorder(64)}, archiveLabelSelector: labels.Everything(), - cacheFactory: controllercache.NewCacheFactory(kube, "default"), + cacheFactory: controllercache.NewCacheFactory(kube), progressPatchTickDuration: envutil.LookupEnvDurationOr(ctx, common.EnvVarProgressPatchTickDuration, 1*time.Minute), progressFileTickDuration: envutil.LookupEnvDurationOr(ctx, common.EnvVarProgressFileTickDuration, 3*time.Second), maxStackDepth: maxAllowedStackDepth, diff --git a/workflow/controller/dag.go b/workflow/controller/dag.go index 57fa5651c342..cdd87bd0c6d9 100644 --- a/workflow/controller/dag.go +++ b/workflow/controller/dag.go @@ -15,7 +15,6 @@ import ( "github.com/argoproj/argo-workflows/v4/util/logging" "github.com/argoproj/argo-workflows/v4/util/template" "github.com/argoproj/argo-workflows/v4/workflow/common" - controllercache "github.com/argoproj/argo-workflows/v4/workflow/controller/cache" "github.com/argoproj/argo-workflows/v4/workflow/templateresolution" ) @@ -379,11 +378,17 @@ func (woc *wfOperationCtx) executeDAG(ctx context.Context, nodeName string, tmpl woc.wf.Status.Nodes.Set(ctx, node.ID, *node) } if node.MemoizationStatus != nil { - c := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, node.MemoizationStatus.CacheName) - saveErr := c.Save(ctx, node.MemoizationStatus.Key, node.ID, node.Outputs) - if saveErr != nil { - woc.log.WithField("nodeID", node.ID).WithError(saveErr).Error(ctx, "Failed to save node outputs to cache") - node.Phase = wfv1.NodeError + c := woc.controller.getMemoizationCache(ctx, woc.wf.Namespace, node.MemoizationStatus.CacheName) + switch { + case c == nil: + woc.log.WithField("nodeID", node.ID).Warn(ctx, "Memoization cache unavailable; skipping cache save") + case tmpl.Memoize == nil: + woc.log.WithField("nodeID", node.ID).Warn(ctx, "Node template has no memoize spec; skipping cache save") + default: + if saveErr := c.Save(ctx, node.MemoizationStatus.Key, node.ID, node.Outputs, tmpl.Memoize.MaxAge); saveErr != nil { + woc.log.WithField("nodeID", node.ID).WithError(saveErr).Error(ctx, "Failed to save node outputs to cache") + node.Phase = wfv1.NodeError + } } } diff --git a/workflow/controller/memoization_gc_test.go b/workflow/controller/memoization_gc_test.go new file mode 100644 index 000000000000..548209841350 --- /dev/null +++ b/workflow/controller/memoization_gc_test.go @@ -0,0 +1,67 @@ +package controller + +import ( + "context" + "testing" + "time" + + wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v4/util/logging" + memodb "github.com/argoproj/argo-workflows/v4/util/memo/db" +) + +type testMemoizationDB struct { + pruned chan struct{} +} + +func (*testMemoizationDB) Load(context.Context, string, string, string) (*memodb.CacheRecord, error) { + return nil, nil +} + +func (*testMemoizationDB) Save(context.Context, string, string, string, string, *wfv1.Outputs, int64) error { + return nil +} + +func (t *testMemoizationDB) Prune(context.Context) (int64, error) { + select { + case t.pruned <- struct{}{}: + default: + } + return 1, nil +} + +func (*testMemoizationDB) IsEnabled() bool { + return true +} + +func TestMemoizationCacheGarbageCollectorHandlesRuntimeEnable(t *testing.T) { + t.Setenv("MEMO_CACHE_GC_PERIOD", "10ms") + + ctx, cancel := context.WithCancel(logging.TestContext(t.Context())) + defer cancel() + + controller := &WorkflowController{} + done := make(chan struct{}) + go func() { + defer close(done) + controller.memoizationCacheGarbageCollector(ctx) + }() + + time.Sleep(25 * time.Millisecond) + + queries := &testMemoizationDB{pruned: make(chan struct{}, 1)} + controller.setMemoizationQueries(queries) + + select { + case <-queries.pruned: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected memoization cache GC to observe runtime enablement and prune") + } + + cancel() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected memoization cache GC goroutine to stop after context cancellation") + } +} diff --git a/workflow/controller/operator.go b/workflow/controller/operator.go index a12da33dd38d..c53512614c2a 100644 --- a/workflow/controller/operator.go +++ b/workflow/controller/operator.go @@ -1196,12 +1196,22 @@ func (woc *wfOperationCtx) podReconciliation(ctx context.Context) (bool, error) woc.addOutputsToGlobalScope(ctx, newState.Outputs) if newState.MemoizationStatus != nil { if newState.Succeeded() { - c := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, newState.MemoizationStatus.CacheName) - err := c.Save(ctx, newState.MemoizationStatus.Key, newState.ID, newState.Outputs) - if err != nil { - woc.log.WithFields(logging.Fields{"nodeID": newState.ID}).WithError(err).Error(ctx, "Failed to save node outputs to cache") - newState.Phase = wfv1.NodeError - newState.Message = err.Error() + c := woc.controller.getMemoizationCache(ctx, woc.wf.Namespace, newState.MemoizationStatus.CacheName) + if c == nil { + woc.log.WithFields(logging.Fields{"nodeID": newState.ID}).Warn(ctx, "Memoization cache unavailable; skipping cache save") + } else { + nodeTmpl, tmplErr := woc.GetNodeTemplate(ctx, newState) + maxAge := "" + if tmplErr != nil { + woc.log.WithFields(logging.Fields{"nodeID": newState.ID}).WithError(tmplErr).Warn(ctx, "Failed to get node template for cache save; using default maxAge") + } else if nodeTmpl != nil && nodeTmpl.Memoize != nil { + maxAge = nodeTmpl.Memoize.MaxAge + } + if err := c.Save(ctx, newState.MemoizationStatus.Key, newState.ID, newState.Outputs, maxAge); err != nil { + woc.log.WithFields(logging.Fields{"nodeID": newState.ID}).WithError(err).Error(ctx, "Failed to save node outputs to cache") + newState.Phase = wfv1.NodeError + newState.Message = err.Error() + } } } } @@ -2228,18 +2238,29 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string, // Check memoization cache if the node is about to be created, or was created in the past but is only now allowed to run due to acquiring a lock if processedTmpl.Memoize != nil { + cacheName := "" + if processedTmpl.Memoize.Cache != nil && processedTmpl.Memoize.Cache.ConfigMap != nil { + cacheName = processedTmpl.Memoize.Cache.ConfigMap.Name + } + if cacheName == "" { + cacheErr := fmt.Errorf("memoize.cache.configMap.name is required") + woc.log.WithError(cacheErr).Error(ctx, "Invalid memoize configuration") + errNode := woc.initializeNodeOrMarkError(ctx, node, nodeName, templateScope, orgTmpl, opts.boundaryID, opts.nodeFlag, cacheErr) + return errNode, cacheErr + } if node == nil || unlockedNode { - memoizationCache := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, processedTmpl.Memoize.Cache.ConfigMap.Name) + memoizationCache := woc.controller.getMemoizationCache(ctx, woc.wf.Namespace, cacheName) if memoizationCache == nil { - cacheErr := fmt.Errorf("cache could not be found or created") - woc.log.WithFields(logging.Fields{"cacheName": processedTmpl.Memoize.Cache.ConfigMap.Name}).WithError(cacheErr) - errNode := woc.initializeNodeOrMarkError(ctx, node, nodeName, templateScope, orgTmpl, opts.boundaryID, opts.nodeFlag, cacheErr) - return errNode, cacheErr + woc.log.WithFields(logging.Fields{"cacheName": cacheName}).Warn(ctx, "Memoization cache unavailable; treating as cache miss") } - entry, loadErr := memoizationCache.Load(ctx, processedTmpl.Memoize.Key) - if loadErr != nil { - return woc.initializeNodeOrMarkError(ctx, node, nodeName, templateScope, orgTmpl, opts.boundaryID, opts.nodeFlag, loadErr), loadErr + var entry *controllercache.Entry + if memoizationCache != nil { + var loadErr error + entry, loadErr = memoizationCache.Load(ctx, processedTmpl.Memoize.Key) + if loadErr != nil { + return woc.initializeNodeOrMarkError(ctx, node, nodeName, templateScope, orgTmpl, opts.boundaryID, opts.nodeFlag, loadErr), loadErr + } } hit := entry.Hit() @@ -2263,7 +2284,7 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string, memoizationStatus := &wfv1.MemoizationStatus{ Hit: hit, Key: processedTmpl.Memoize.Key, - CacheName: processedTmpl.Memoize.Cache.ConfigMap.Name, + CacheName: cacheName, } if hit { if node == nil { @@ -2808,14 +2829,19 @@ func (woc *wfOperationCtx) initializeExecutableNode(ctx context.Context, nodeNam node.Inputs = executeTmpl.Inputs.DeepCopy() } - // Set the MemoizationStatus - if node.MemoizationStatus == nil && executeTmpl.Memoize != nil { - memoizationStatus := &wfv1.MemoizationStatus{ + // Set the MemoizationStatus only when the user explicitly supplied a key. + // When key is auto-derived, executeTemplate sets MemoizationStatus with the + // computed effectiveKey before reaching this point. + if node.MemoizationStatus == nil && executeTmpl.Memoize != nil && executeTmpl.Memoize.Key != "" { + cacheName := "" + if executeTmpl.Memoize.Cache != nil { + cacheName = executeTmpl.Memoize.Cache.ConfigMap.Name + } + node.MemoizationStatus = &wfv1.MemoizationStatus{ Hit: false, Key: executeTmpl.Memoize.Key, - CacheName: executeTmpl.Memoize.Cache.ConfigMap.Name, + CacheName: cacheName, } - node.MemoizationStatus = memoizationStatus } if nodeType == wfv1.NodeTypeSuspend { diff --git a/workflow/controller/operator_test.go b/workflow/controller/operator_test.go index 1547a90cee5d..f3f537460e9a 100644 --- a/workflow/controller/operator_test.go +++ b/workflow/controller/operator_test.go @@ -6301,6 +6301,7 @@ apiVersion: argoproj.io/v1alpha1 kind: Workflow metadata: name: memoized-workflow-test + namespace: default spec: entrypoint: whalesay arguments: @@ -6389,11 +6390,33 @@ func TestConfigMapCacheLoadOperateMaxAge(t *testing.T) { } } +func TestUnavailableSQLMemoizationBackendTreatsLookupAsCacheMiss(t *testing.T) { + wf := wfv1.MustUnmarshalWorkflow(workflowCachedMaxAge) + ctx := logging.TestContext(t.Context()) + cancel, controller := newController(ctx) + defer cancel() + + _, err := controller.wfclientset.ArgoprojV1alpha1().Workflows(wf.ObjectMeta.Namespace).Create(ctx, wf, metav1.CreateOptions{}) + require.NoError(t, err) + + woc := newWorkflowOperationCtx(ctx, wf, controller) + woc.operate(ctx) + + require.Len(t, woc.wf.Status.Nodes, 1) + for _, node := range woc.wf.Status.Nodes { + assert.Equal(t, wfv1.NodePending, node.Phase) + assert.NotNil(t, node.MemoizationStatus) + assert.False(t, node.MemoizationStatus.Hit) + assert.Nil(t, node.Outputs) + } +} + var workflowStepCachedWithRetryStrategy = ` apiVersion: argoproj.io/v1alpha1 kind: Workflow metadata: name: memoized-workflow-test + namespace: default spec: entrypoint: whalesay arguments: @@ -6428,6 +6451,7 @@ apiVersion: argoproj.io/v1alpha1 kind: Workflow metadata: generateName: memoized-workflow-test + namespace: default spec: entrypoint: main # podGC: diff --git a/workflow/controller/steps.go b/workflow/controller/steps.go index 516850d738d4..3fecd89795cb 100644 --- a/workflow/controller/steps.go +++ b/workflow/controller/steps.go @@ -18,7 +18,6 @@ import ( "github.com/argoproj/argo-workflows/v4/util/logging" "github.com/argoproj/argo-workflows/v4/util/template" "github.com/argoproj/argo-workflows/v4/workflow/common" - controllercache "github.com/argoproj/argo-workflows/v4/workflow/controller/cache" "github.com/argoproj/argo-workflows/v4/workflow/templateresolution" ) @@ -188,11 +187,17 @@ func (woc *wfOperationCtx) executeSteps(ctx context.Context, nodeName string, tm } if node.MemoizationStatus != nil { - c := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, node.MemoizationStatus.CacheName) - err := c.Save(ctx, node.MemoizationStatus.Key, node.ID, node.Outputs) - if err != nil { - woc.log.WithFields(logging.Fields{"nodeID": node.ID}).WithError(err).Error(ctx, "Failed to save node outputs to cache") - node.Phase = wfv1.NodeError + c := woc.controller.getMemoizationCache(ctx, woc.wf.Namespace, node.MemoizationStatus.CacheName) + switch { + case c == nil: + woc.log.WithFields(logging.Fields{"nodeID": node.ID}).Warn(ctx, "Memoization cache unavailable; skipping cache save") + case tmpl.Memoize == nil: + woc.log.WithFields(logging.Fields{"nodeID": node.ID}).Warn(ctx, "Node template has no memoize spec; skipping cache save") + default: + if err := c.Save(ctx, node.MemoizationStatus.Key, node.ID, node.Outputs, tmpl.Memoize.MaxAge); err != nil { + woc.log.WithFields(logging.Fields{"nodeID": node.ID}).WithError(err).Error(ctx, "Failed to save node outputs to cache") + node.Phase = wfv1.NodeError + } } } return woc.markNodePhase(ctx, nodeName, wfv1.NodeSucceeded), nil diff --git a/workflow/controller/taskset.go b/workflow/controller/taskset.go index 86644bc5f7fa..1618468f8839 100644 --- a/workflow/controller/taskset.go +++ b/workflow/controller/taskset.go @@ -15,7 +15,6 @@ import ( wfv1 "github.com/argoproj/argo-workflows/v4/pkg/apis/workflow/v1alpha1" "github.com/argoproj/argo-workflows/v4/util/logging" "github.com/argoproj/argo-workflows/v4/workflow/common" - controllercache "github.com/argoproj/argo-workflows/v4/workflow/controller/cache" ) func (woc *wfOperationCtx) mergePatchTaskSet(ctx context.Context, patch any, subresources ...string) error { @@ -154,10 +153,20 @@ func (woc *wfOperationCtx) reconcileTaskSet(ctx context.Context) error { woc.wf.Status.Nodes.Set(ctx, nodeID, *node) if node.MemoizationStatus != nil && node.Succeeded() { - c := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, node.MemoizationStatus.CacheName) - err := c.Save(ctx, node.MemoizationStatus.Key, node.ID, node.Outputs) - if err != nil { - woc.log.WithFields(logging.Fields{"nodeID": node.ID}).WithError(err).Error(ctx, "Failed to save node outputs to cache") + c := woc.controller.getMemoizationCache(ctx, woc.wf.Namespace, node.MemoizationStatus.CacheName) + if c == nil { + woc.log.WithFields(logging.Fields{"nodeID": node.ID}).Warn(ctx, "Memoization cache unavailable; skipping cache save") + } else { + nodeTmpl, tmplErr := woc.GetNodeTemplate(ctx, node) + maxAge := "" + if tmplErr != nil { + woc.log.WithFields(logging.Fields{"nodeID": node.ID}).WithError(tmplErr).Warn(ctx, "Failed to get node template for cache save; using default maxAge") + } else if nodeTmpl != nil && nodeTmpl.Memoize != nil { + maxAge = nodeTmpl.Memoize.MaxAge + } + if err := c.Save(ctx, node.MemoizationStatus.Key, node.ID, node.Outputs, maxAge); err != nil { + woc.log.WithFields(logging.Fields{"nodeID": node.ID}).WithError(err).Error(ctx, "Failed to save node outputs to cache") + } } } woc.updated = true diff --git a/workflow/sync/multi_throttler_test.go b/workflow/sync/multi_throttler_test.go index b38e976e0d90..0277af9dee31 100644 --- a/workflow/sync/multi_throttler_test.go +++ b/workflow/sync/multi_throttler_test.go @@ -207,12 +207,13 @@ func TestPriorityAcrossNamespaces(t *testing.T) { func TestParallelismUpdate(t *testing.T) { assert := assert.New(t) throttler := NewMultiThrottler(4, 0, func(Key) {}) - throttler.Add("a/0", 0, time.Now()) - throttler.Add("b/0", 0, time.Now()) - throttler.Add("c/0", 0, time.Now()) - throttler.Add("d/0", 0, time.Now()) - throttler.Add("e/0", 0, time.Now()) - throttler.Add("f/0", 0, time.Now()) + base := time.Unix(0, 0) + throttler.Add("a/0", 0, base) + throttler.Add("b/0", 0, base.Add(time.Second)) + throttler.Add("c/0", 0, base.Add(2*time.Second)) + throttler.Add("d/0", 0, base.Add(3*time.Second)) + throttler.Add("e/0", 0, base.Add(4*time.Second)) + throttler.Add("f/0", 0, base.Add(5*time.Second)) assert.True(throttler.Admit("a/0")) assert.True(throttler.Admit("b/0"))