diff --git a/crates/core/src/types/mod.rs b/crates/core/src/types/mod.rs index 05203998..f4c9c32b 100755 --- a/crates/core/src/types/mod.rs +++ b/crates/core/src/types/mod.rs @@ -52,12 +52,13 @@ pub use table::{ DeleteGsiAction, DeleteTableInput, DeleteTableOutput, DescribeLimitsOutput, DescribeTableInput, DescribeTableOutput, DescribeTimeToLiveInput, DescribeTimeToLiveOutput, GlobalSecondaryIndexUpdate, GsiDescription, GsiInput, ListTablesInput, ListTablesOutput, - ListTagsOfResourceInput, ListTagsOfResourceOutput, LsiDescription, LsiInput, Projection, - ProjectionType, ProvisionedThroughput, ProvisionedThroughputDescription, SseDescription, - SseType, StreamSpecification, StreamViewType, TableDescription, TableStatus, Tag, - TagResourceInput, TimeToLiveDescription, TimeToLiveSpecification, - TimeToLiveSpecificationOutput, TimeToLiveStatus, UntagResourceInput, UpdateGsiAction, - UpdateTableInput, UpdateTableOutput, UpdateTimeToLiveInput, UpdateTimeToLiveOutput, + ListTagsOfResourceInput, ListTagsOfResourceOutput, LsiDescription, LsiInput, + OnDemandThroughput, Projection, ProjectionType, ProvisionedThroughput, + ProvisionedThroughputDescription, SseDescription, SseType, StreamSpecification, StreamViewType, + TableDescription, TableStatus, Tag, TagResourceInput, TimeToLiveDescription, + TimeToLiveSpecification, TimeToLiveSpecificationOutput, TimeToLiveStatus, UntagResourceInput, + UpdateGsiAction, UpdateTableInput, UpdateTableOutput, UpdateTimeToLiveInput, + UpdateTimeToLiveOutput, }; pub use transaction::{ CancellationReason, ItemResponse, TransactConditionCheck, TransactDelete, TransactGet, diff --git a/crates/core/src/types/table.rs b/crates/core/src/types/table.rs index eacba35b..dc2f3cf5 100755 --- a/crates/core/src/types/table.rs +++ b/crates/core/src/types/table.rs @@ -122,6 +122,17 @@ pub struct SseDescription { pub status: String, #[serde(rename = "SSEType", skip_serializing_if = "Option::is_none")] pub sse_type: Option, + #[serde(rename = "KMSMasterKeyArn", skip_serializing_if = "Option::is_none")] + pub kms_master_key_arn: Option, +} + +/// On-demand throughput settings for a table or index. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct OnDemandThroughput { + #[serde(rename = "MaxReadRequestUnits")] + pub max_read_request_units: Option, + #[serde(rename = "MaxWriteRequestUnits")] + pub max_write_request_units: Option, } /// Summary of the table's billing mode and last update timestamp. @@ -264,6 +275,8 @@ pub struct TableDescription { pub sse_description: Option, #[serde(rename = "TableClassSummary", skip_serializing_if = "Option::is_none")] pub table_class_summary: Option, + #[serde(rename = "OnDemandThroughput", skip_serializing_if = "Option::is_none")] + pub on_demand_throughput: Option, } /// `CreateTable` request body. @@ -293,6 +306,8 @@ pub struct CreateTableInput { pub deletion_protection_enabled: Option, #[serde(rename = "TableClass")] pub table_class: Option, + #[serde(rename = "OnDemandThroughput")] + pub on_demand_throughput: Option, } /// `CreateTable` response body. @@ -414,6 +429,10 @@ pub struct UpdateTableInput { pub attribute_definitions: Option>, #[serde(rename = "StreamSpecification")] pub stream_specification: Option, + #[serde(rename = "TableClass")] + pub table_class: Option, + #[serde(rename = "OnDemandThroughput")] + pub on_demand_throughput: Option, } /// `UpdateTable` response body. @@ -544,3 +563,72 @@ pub struct DescribeLimitsOutput { #[serde(rename = "TableMaxWriteCapacityUnits")] pub table_max_write_capacity_units: i64, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn on_demand_throughput_round_trips_json() { + let odt = OnDemandThroughput { + max_read_request_units: Some(100), + max_write_request_units: Some(50), + }; + let json = serde_json::to_value(&odt).unwrap(); + assert_eq!(json["MaxReadRequestUnits"], 100); + assert_eq!(json["MaxWriteRequestUnits"], 50); + let parsed: OnDemandThroughput = serde_json::from_value(json).unwrap(); + assert_eq!(parsed, odt); + } + + #[test] + fn on_demand_throughput_deserializes_from_input() { + let input_json = r#"{"MaxReadRequestUnits": 10, "MaxWriteRequestUnits": 5}"#; + let odt: OnDemandThroughput = serde_json::from_str(input_json).unwrap(); + assert_eq!(odt.max_read_request_units, Some(10)); + assert_eq!(odt.max_write_request_units, Some(5)); + } + + #[test] + fn sse_description_serializes_with_kms_arn() { + let sse = SseDescription { + status: "ENABLED".to_string(), + sse_type: Some(SseType::KMS), + kms_master_key_arn: Some("arn:aws:kms:us-east-1:123456789012:key/default".to_string()), + }; + let json = serde_json::to_value(&sse).unwrap(); + assert_eq!(json["Status"], "ENABLED"); + assert_eq!(json["SSEType"], "KMS"); + assert_eq!( + json["KMSMasterKeyArn"], + "arn:aws:kms:us-east-1:123456789012:key/default" + ); + } + + #[test] + fn sse_description_omits_none_fields() { + let sse = SseDescription { + status: "ENABLED".to_string(), + sse_type: None, + kms_master_key_arn: None, + }; + let json = serde_json::to_value(&sse).unwrap(); + assert_eq!(json["Status"], "ENABLED"); + assert!(json.get("SSEType").is_none()); + assert!(json.get("KMSMasterKeyArn").is_none()); + } + + #[test] + fn create_table_input_deserializes_on_demand_throughput() { + let json = r#"{ + "TableName": "T", + "KeySchema": [{"AttributeName": "pk", "KeyType": "HASH"}], + "AttributeDefinitions": [{"AttributeName": "pk", "AttributeType": "S"}], + "OnDemandThroughput": {"MaxReadRequestUnits": 10, "MaxWriteRequestUnits": 5} + }"#; + let input: CreateTableInput = serde_json::from_str(json).unwrap(); + let odt = input.on_demand_throughput.unwrap(); + assert_eq!(odt.max_read_request_units, Some(10)); + assert_eq!(odt.max_write_request_units, Some(5)); + } +} diff --git a/crates/core/src/validation/mod.rs b/crates/core/src/validation/mod.rs index 33685ea3..75afa0cd 100755 --- a/crates/core/src/validation/mod.rs +++ b/crates/core/src/validation/mod.rs @@ -961,6 +961,7 @@ mod tests { tags: None, deletion_protection_enabled: None, table_class: None, + on_demand_throughput: None, } } diff --git a/crates/engine/src/import_export.rs b/crates/engine/src/import_export.rs index a189f113..ec3788ce 100755 --- a/crates/engine/src/import_export.rs +++ b/crates/engine/src/import_export.rs @@ -309,6 +309,7 @@ fn create_table_input_from_params(tcp: &TableCreationParameters) -> CreateTableI tags: None, deletion_protection_enabled: None, table_class: None, + on_demand_throughput: None, } } diff --git a/crates/engine/src/update_table.rs b/crates/engine/src/update_table.rs index df3d31cc..ec8b291d 100755 --- a/crates/engine/src/update_table.rs +++ b/crates/engine/src/update_table.rs @@ -45,6 +45,8 @@ pub async fn handle_update_table( && input.provisioned_throughput.is_none() && input.deletion_protection_enabled.is_none() && input.stream_specification.is_none() + && input.table_class.is_none() + && input.on_demand_throughput.is_none() && !has_gsi_updates { return Err(DynamoDbError::ValidationException( diff --git a/crates/storage-postgres/migrations/001_schema.sql b/crates/storage-postgres/migrations/001_schema.sql index 80549c7f..ad0be3c8 100644 --- a/crates/storage-postgres/migrations/001_schema.sql +++ b/crates/storage-postgres/migrations/001_schema.sql @@ -32,6 +32,9 @@ CREATE TABLE IF NOT EXISTS tables ( status_transition_at TIMESTAMPTZ, stream_label TEXT, ttl_index_ready BOOLEAN NOT NULL DEFAULT FALSE, + table_class TEXT, + sse_specification JSONB, + on_demand_throughput JSONB, PRIMARY KEY (account_id, table_name), CONSTRAINT tables_table_id_unique UNIQUE (table_id) ); diff --git a/crates/storage-postgres/src/backup_engine.rs b/crates/storage-postgres/src/backup_engine.rs index ba14d401..7ff08715 100755 --- a/crates/storage-postgres/src/backup_engine.rs +++ b/crates/storage-postgres/src/backup_engine.rs @@ -415,6 +415,7 @@ impl BackupEngine for PostgresEngine { deletion_protection_enabled: None, sse_specification: None, table_class: None, + on_demand_throughput: None, }; let desc = self.create_table(&account_id, create_input).await?; diff --git a/crates/storage-postgres/src/create_table.rs b/crates/storage-postgres/src/create_table.rs index 65f9e460..62613963 100755 --- a/crates/storage-postgres/src/create_table.rs +++ b/crates/storage-postgres/src/create_table.rs @@ -5,7 +5,7 @@ use extenddb_core::types::{ BillingMode, BillingModeSummary, CreateTableInput, GsiDescription, LsiDescription, - ProvisionedThroughputDescription, TableDescription, TableStatus, + ProvisionedThroughputDescription, SseDescription, SseType, TableDescription, TableStatus, }; use extenddb_storage::error::StorageError; use extenddb_storage::util::{index_arn, stream_arn, table_arn}; @@ -45,6 +45,13 @@ impl PostgresEngine { .transpose() .map_err(|e| StorageError::Internal(e.to_string()))?; let deletion_protection = input.deletion_protection_enabled.unwrap_or(false); + let sse_spec_json = input.sse_specification.as_ref().cloned(); + let on_demand_json = input + .on_demand_throughput + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(|e| StorageError::Internal(e.to_string()))?; let mut tx = self .pool @@ -68,7 +75,7 @@ impl PostgresEngine { (account_id, table_name, key_schema, attribute_definitions, billing_mode, provisioned_throughput, stream_specification, table_status, creation_date_time, table_arn, table_id, deletion_protection_enabled, - status_transition_at) + status_transition_at, table_class, sse_specification, on_demand_throughput) VALUES ($1, $2, $3, $4, $5, $6, $7, CASE WHEN (SELECT secs FROM delay) = 0 THEN 'ACTIVE' ELSE 'CREATING' END, @@ -76,7 +83,8 @@ impl PostgresEngine { CASE WHEN (SELECT secs FROM delay) = 0 THEN NULL ELSE NOW() + make_interval(secs => (SELECT secs FROM delay)) - END) + END, + $11, $12, $13) RETURNING EXTRACT(EPOCH FROM creation_date_time)::FLOAT8, table_status", ) .bind(account_id) @@ -89,6 +97,9 @@ impl PostgresEngine { .bind(&table_arn) .bind(&table_id) .bind(deletion_protection) + .bind(&input.table_class) + .bind(&sse_spec_json) + .bind(&on_demand_json) .fetch_one(&mut *tx) .await .map_err(|e| match &e { @@ -386,8 +397,29 @@ impl PostgresEngine { latest_stream_arn, latest_stream_label: stream_label, deletion_protection_enabled: input.deletion_protection_enabled.unwrap_or(false), - sse_description: None, - table_class_summary: None, + sse_description: input.sse_specification.as_ref().and_then(|spec| { + let enabled = spec + .get("Enabled") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + if enabled { + Some(SseDescription { + status: "ENABLED".to_string(), + sse_type: Some(SseType::KMS), + kms_master_key_arn: Some(format!( + "arn:aws:kms:{}:{}:key/default", + self.region, account_id + )), + }) + } else { + None + } + }), + table_class_summary: input + .table_class + .as_ref() + .map(|tc| serde_json::json!({ "TableClass": tc })), + on_demand_throughput: input.on_demand_throughput, }) } } diff --git a/crates/storage-postgres/src/delete_table.rs b/crates/storage-postgres/src/delete_table.rs index b89a5326..6da8a4ca 100755 --- a/crates/storage-postgres/src/delete_table.rs +++ b/crates/storage-postgres/src/delete_table.rs @@ -29,7 +29,8 @@ impl PostgresEngine { provisioned_throughput, stream_specification, table_status, EXTRACT(EPOCH FROM creation_date_time)::FLOAT8 as creation_epoch, table_size_bytes, item_count, table_arn, table_id, - deletion_protection_enabled, stream_label + deletion_protection_enabled, stream_label, + table_class, sse_specification, on_demand_throughput FROM tables WHERE account_id = $1 AND table_name = $2 AND table_status IN ('ACTIVE', 'CREATING') FOR UPDATE", ) diff --git a/crates/storage-postgres/src/table_helpers.rs b/crates/storage-postgres/src/table_helpers.rs index ef88d332..0d2d4b55 100755 --- a/crates/storage-postgres/src/table_helpers.rs +++ b/crates/storage-postgres/src/table_helpers.rs @@ -5,7 +5,8 @@ use extenddb_core::types::{ AttributeDefinition, BillingMode, BillingModeSummary, GsiDescription, KeySchemaElement, - LsiDescription, Projection, ProvisionedThroughputDescription, TableDescription, TableStatus, + LsiDescription, Projection, ProvisionedThroughputDescription, SseDescription, SseType, + TableDescription, TableStatus, }; use extenddb_storage::error::StorageError; use extenddb_storage::util::{index_arn, stream_arn}; @@ -30,6 +31,9 @@ pub(crate) struct TableRow { pub table_id: String, pub deletion_protection_enabled: bool, pub stream_label: Option, + pub table_class: Option, + pub sse_specification: Option, + pub on_demand_throughput: Option, } /// Row type for index metadata queries. @@ -149,7 +153,8 @@ impl PostgresEngine { provisioned_throughput, stream_specification, table_status, EXTRACT(EPOCH FROM creation_date_time)::FLOAT8 as creation_epoch, table_size_bytes, item_count, table_arn, table_id, - deletion_protection_enabled, stream_label + deletion_protection_enabled, stream_label, + table_class, sse_specification, on_demand_throughput FROM tables WHERE account_id = $1 AND table_name = $2", ) .bind(account_id) @@ -322,8 +327,31 @@ impl PostgresEngine { latest_stream_arn, latest_stream_label: row.stream_label, deletion_protection_enabled: row.deletion_protection_enabled, - sse_description: None, - table_class_summary: None, + sse_description: row.sse_specification.as_ref().and_then(|spec| { + let enabled = spec + .get("Enabled") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + if enabled { + Some(SseDescription { + status: "ENABLED".to_string(), + sse_type: Some(SseType::KMS), + kms_master_key_arn: Some(format!( + "arn:aws:kms:{}:{}:key/default", + self.region, account_id + )), + }) + } else { + None + } + }), + table_class_summary: row + .table_class + .as_ref() + .map(|tc| serde_json::json!({ "TableClass": tc })), + on_demand_throughput: row + .on_demand_throughput + .and_then(|v| serde_json::from_value(v).ok()), }) } } diff --git a/crates/storage-postgres/src/update_table.rs b/crates/storage-postgres/src/update_table.rs index 8fbb76b5..2752b56b 100755 --- a/crates/storage-postgres/src/update_table.rs +++ b/crates/storage-postgres/src/update_table.rs @@ -201,6 +201,35 @@ impl PostgresEngine { } } + // Apply table class change. + if let Some(tc) = &input.table_class { + sqlx::query( + "UPDATE tables SET table_class = $1 WHERE account_id = $2 AND table_name = $3", + ) + .bind(tc) + .bind(account_id) + .bind(&input.table_name) + .execute(&mut *tx) + .await + .map_err(|e| StorageError::Internal(e.to_string()))?; + } + + // Apply on-demand throughput change. + if let Some(odt) = &input.on_demand_throughput { + let odt_json = + serde_json::to_value(odt).map_err(|e| StorageError::Internal(e.to_string()))?; + sqlx::query( + "UPDATE tables SET on_demand_throughput = $1 \ + WHERE account_id = $2 AND table_name = $3", + ) + .bind(&odt_json) + .bind(account_id) + .bind(&input.table_name) + .execute(&mut *tx) + .await + .map_err(|e| StorageError::Internal(e.to_string()))?; + } + // Apply GSI updates (create/delete). let mut created_index_ids: Vec = Vec::new(); let mut deleted_index_ids: Vec = Vec::new(); diff --git a/tests/test_config_fields.py b/tests/test_config_fields.py new file mode 100644 index 00000000..3549927b --- /dev/null +++ b/tests/test_config_fields.py @@ -0,0 +1,97 @@ +# Copyright 2026 ExtendDB contributors +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for TableClass, SSESpecification, and OnDemandThroughput round-tripping.""" + +from __future__ import annotations + +import pytest +from conftest import wait_for_active + + +class TestTableClass: + """TableClass field round-trips through CreateTable and UpdateTable.""" + + def test_table_class_infrequent_access(self, create_and_cleanup_table, dynamodb_client): + result = create_and_cleanup_table( + TableClass="STANDARD_INFREQUENT_ACCESS", + ) + desc = dynamodb_client.describe_table( + TableName=result["TableDescription"]["TableName"] + ) + assert desc["Table"]["TableClassSummary"]["TableClass"] == "STANDARD_INFREQUENT_ACCESS" + + def test_table_class_default_standard(self, create_and_cleanup_table, dynamodb_client): + result = create_and_cleanup_table() + desc = dynamodb_client.describe_table( + TableName=result["TableDescription"]["TableName"] + ) + # STANDARD may be omitted entirely or reported explicitly + tc = desc["Table"].get("TableClassSummary", {}).get("TableClass", "STANDARD") + assert tc == "STANDARD" + + def test_update_table_class(self, create_and_cleanup_table, dynamodb_client): + result = create_and_cleanup_table() + table_name = result["TableDescription"]["TableName"] + dynamodb_client.update_table( + TableName=table_name, + TableClass="STANDARD_INFREQUENT_ACCESS", + ) + wait_for_active(dynamodb_client, table_name) + desc = dynamodb_client.describe_table(TableName=table_name) + assert desc["Table"]["TableClassSummary"]["TableClass"] == "STANDARD_INFREQUENT_ACCESS" + + +class TestSSESpecification: + """SSESpecification round-trips as SSEDescription in DescribeTable.""" + + def test_sse_enabled_round_trips(self, create_and_cleanup_table, dynamodb_client): + result = create_and_cleanup_table( + SSESpecification={"Enabled": True}, + ) + desc = dynamodb_client.describe_table( + TableName=result["TableDescription"]["TableName"] + ) + sse = desc["Table"]["SSEDescription"] + assert sse["Status"] == "ENABLED" + assert sse["SSEType"] == "KMS" + assert "arn:aws:kms:" in sse["KMSMasterKeyArn"] + + +class TestOnDemandThroughput: + """OnDemandThroughput round-trips through CreateTable and UpdateTable.""" + + def test_on_demand_throughput_create(self, create_and_cleanup_table, dynamodb_client): + result = create_and_cleanup_table( + OnDemandThroughput={ + "MaxReadRequestUnits": 10, + "MaxWriteRequestUnits": 5, + }, + ) + desc = dynamodb_client.describe_table( + TableName=result["TableDescription"]["TableName"] + ) + odt = desc["Table"]["OnDemandThroughput"] + assert odt["MaxReadRequestUnits"] == 10 + assert odt["MaxWriteRequestUnits"] == 5 + + def test_update_on_demand_throughput(self, create_and_cleanup_table, dynamodb_client): + result = create_and_cleanup_table( + OnDemandThroughput={ + "MaxReadRequestUnits": 10, + "MaxWriteRequestUnits": 5, + }, + ) + table_name = result["TableDescription"]["TableName"] + dynamodb_client.update_table( + TableName=table_name, + OnDemandThroughput={ + "MaxReadRequestUnits": 20, + "MaxWriteRequestUnits": 15, + }, + ) + wait_for_active(dynamodb_client, table_name) + desc = dynamodb_client.describe_table(TableName=table_name) + odt = desc["Table"]["OnDemandThroughput"] + assert odt["MaxReadRequestUnits"] == 20 + assert odt["MaxWriteRequestUnits"] == 15