diff --git a/app/adapters/dynamodb_unit_of_work.py b/app/adapters/dynamodb_unit_of_work.py index bf14676..d0d4909 100644 --- a/app/adapters/dynamodb_unit_of_work.py +++ b/app/adapters/dynamodb_unit_of_work.py @@ -4,9 +4,12 @@ from mypy_boto3_dynamodb import client from app.adapters.internal import dynamodb_base +from app.domain.exceptions import repository_exception from app.domain.model import product, product_version from app.domain.ports import unit_of_work +DYNAMODB_TRANSACTION_LIMIT = 25 + class DBPrefix(enum.Enum): PRODUCT = "PRODUCT" @@ -56,8 +59,17 @@ def update_attributes(self, product_id: str, **kwargs) -> None: ) def delete(self, product_id: str) -> None: - key = self.generate_product_key(product_id) - self.delete_generic_item(key=key) + """Deletes all records with the given product_id as partition key (product + versions).""" + pk_value = f"{DBPrefix.PRODUCT.value}#{product_id}" + request = self._create_query_by_pk_request(pk_value) + items = self._context.query_items(request) + if len(items) > DYNAMODB_TRANSACTION_LIMIT: + raise repository_exception.RepositoryException( + f"Cannot delete: {len(items)} items exceed DynamoDB transaction limit of {DYNAMODB_TRANSACTION_LIMIT}." + ) + for item in items: + key = {"PK": item["PK"], "SK": item["SK"]} + self.delete_generic_item(key=key) @staticmethod def generate_product_key(product_id: str) -> dict: diff --git a/app/adapters/internal/dynamodb_base.py b/app/adapters/internal/dynamodb_base.py index cc6b8bd..2c74847 100644 --- a/app/adapters/internal/dynamodb_base.py +++ b/app/adapters/internal/dynamodb_base.py @@ -36,6 +36,20 @@ def get_generic_item(self, request: dict) -> Any: return item["Item"] if "Item" in item else None + def query_items(self, query: dict) -> List[dict]: + """ + Queries all items with the given partition key value. + Returns a list of items (each with PK and SK in DynamoDB format). + """ + items: List[dict] = [] + while True: + response = self._dynamo_db_client.query(**query) + items.extend(response.get("Items", [])) + if "LastEvaluatedKey" not in response: + break + query["ExclusiveStartKey"] = response["LastEvaluatedKey"] + return items + class DynamoDBRepository: """Generic DynamoDB repository.""" @@ -86,3 +100,10 @@ def _create_get_request(self, key: dict) -> dict: def _create_delete_modifier(self, key: dict) -> dict: return {"Delete": {"TableName": self._table_name, "Key": key}} + + def _create_query_by_pk_request(self, pk_value: str) -> dict: + return { + "TableName": self._table_name, + "KeyConditionExpression": "PK = :pk", + "ExpressionAttributeValues": {":pk": pk_value}, + } diff --git a/app/adapters/tests/test_dynamodb_unit_of_work.py b/app/adapters/tests/test_dynamodb_unit_of_work.py index a75d5ff..eb2efdc 100644 --- a/app/adapters/tests/test_dynamodb_unit_of_work.py +++ b/app/adapters/tests/test_dynamodb_unit_of_work.py @@ -7,7 +7,7 @@ import pytest from app.adapters import dynamodb_unit_of_work -from app.domain.model import product +from app.domain.model import product, product_version TEST_TABLE_NAME = "test-table" @@ -168,3 +168,60 @@ def test_delete_and_commit_should_delete_product(mock_dynamodb): product_from_db = unit_of_work_readonly.products.get(new_product_id) assertpy.assert_that(product_from_db).is_none() + + +def test_delete_should_remove_product_and_all_versions(mock_dynamodb): + """Delete removes all records with the product_id as partition key (product + versions).""" + # Arrange + unit_of_work = dynamodb_unit_of_work.DynamoDBUnitOfWork( + table_name=TEST_TABLE_NAME, dynamodb_client=mock_dynamodb.meta.client + ) + unit_of_work_readonly = dynamodb_unit_of_work.DynamoDBUnitOfWork( + table_name=TEST_TABLE_NAME, dynamodb_client=mock_dynamodb.meta.client + ) + current_time = datetime.datetime.now(datetime.timezone.utc).isoformat() + + new_product_id = str(uuid.uuid4()) + new_product = product.Product( + id=new_product_id, + name="test-name", + description="test-description", + createDate=current_time, + lastUpdateDate=current_time, + ) + version_1 = product_version.ProductVersion( + id="v1", + name="Version 1", + version="1.0.0", + createDate=current_time, + ) + version_2 = product_version.ProductVersion( + id="v2", + name="Version 2", + version="2.0.0", + createDate=current_time, + ) + with unit_of_work: + unit_of_work.products.add(new_product) + unit_of_work.product_versions.add(new_product_id, version_1) + unit_of_work.product_versions.add(new_product_id, version_2) + unit_of_work.commit() + + # Act + with unit_of_work: + unit_of_work.products.delete(new_product_id) + unit_of_work.commit() + + # Assert - product and all versions should be gone + with unit_of_work_readonly: + product_from_db = unit_of_work_readonly.products.get(new_product_id) + version_1_from_db = unit_of_work_readonly.product_versions.get( + new_product_id, "v1" + ) + version_2_from_db = unit_of_work_readonly.product_versions.get( + new_product_id, "v2" + ) + + assertpy.assert_that(product_from_db).is_none() + assertpy.assert_that(version_1_from_db).is_none() + assertpy.assert_that(version_2_from_db).is_none() diff --git a/app/domain/command_handlers/add_product_version_command_handler.py b/app/domain/command_handlers/add_product_version_command_handler.py new file mode 100644 index 0000000..8f479f3 --- /dev/null +++ b/app/domain/command_handlers/add_product_version_command_handler.py @@ -0,0 +1,33 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional + +from app.domain.commands import add_product_version_command +from app.domain.model import product, product_version +from app.domain.ports import unit_of_work + + +def handle_add_product_version_command( + command: add_product_version_command.AddProductVersionCommand, + unit_of_work: unit_of_work.UnitOfWork, +) -> Optional[str]: + + with unit_of_work: + product_obj = unit_of_work.products.get(product_id=command.product_id) + if not product_obj: + return None + + current_time = datetime.now(timezone.utc).isoformat() + id = str(uuid.uuid4()) + + version_obj = product_version.ProductVersion( + id=id, + name=command.name, + version=command.version, + createDate=current_time, + ) + + unit_of_work.product_versions.add(command.product_id, version_obj) + unit_of_work.commit() + + return id diff --git a/app/domain/command_handlers/create_product_command_handler.py b/app/domain/command_handlers/create_product_command_handler.py index 1dcaf21..f9d2d4c 100644 --- a/app/domain/command_handlers/create_product_command_handler.py +++ b/app/domain/command_handlers/create_product_command_handler.py @@ -10,6 +10,7 @@ def handle_create_product_command( command: create_product_command.CreateProductCommand, unit_of_work: unit_of_work.UnitOfWork, ) -> str: + current_time = datetime.now(timezone.utc).isoformat() id = str(uuid.uuid4()) diff --git a/app/domain/command_handlers/get_product_command_handler.py b/app/domain/command_handlers/get_product_command_handler.py index 1984494..da36832 100644 --- a/app/domain/command_handlers/get_product_command_handler.py +++ b/app/domain/command_handlers/get_product_command_handler.py @@ -10,6 +10,6 @@ def handle_get_product_command( query_service: products_query_service.ProductsQueryService, ) -> Optional[product.Product]: - product_obj = query_service.get_product_by_id(command.id) + product_obj = query_service.get_product_by_id(product_id=command.id) return product_obj diff --git a/app/domain/commands/add_product_version_command.py b/app/domain/commands/add_product_version_command.py new file mode 100644 index 0000000..a159adb --- /dev/null +++ b/app/domain/commands/add_product_version_command.py @@ -0,0 +1,9 @@ +from typing import Optional + +from pydantic import BaseModel + + +class AddProductVersionCommand(BaseModel): + product_id: str + name: Optional[str] + version: str diff --git a/app/domain/tests/test_command_handlers.py b/app/domain/tests/test_command_handlers.py index a4fd553..2b3ffa9 100644 --- a/app/domain/tests/test_command_handlers.py +++ b/app/domain/tests/test_command_handlers.py @@ -9,6 +9,7 @@ create_product_command_handler, delete_product_command_handler, update_product_command_handler, + add_product_version_command_handler, ) from app.domain.commands import ( get_product_command, @@ -16,6 +17,7 @@ create_product_command, delete_product_command, update_product_command, + add_product_version_command, ) from app.domain.ports import products_query_service, unit_of_work @@ -64,14 +66,21 @@ def test_list_products_should_query_from_repository(): ) -def test_create_product_should_store_in_repository(): - # Arrange - mock_unit_of_work = unittest.mock.create_autospec( - spec=unit_of_work.UnitOfWork, instance=True - ) - mock_unit_of_work.products = unittest.mock.create_autospec( +def _create_mock_unit_of_work(): + """Create a mock UnitOfWork that works correctly as a context manager.""" + mock_uow = unittest.mock.MagicMock() + mock_uow.products = unittest.mock.create_autospec( spec=unit_of_work.ProductsRepository, instance=True ) + mock_uow.commit = unittest.mock.Mock() + mock_uow.__enter__ = unittest.mock.Mock(return_value=mock_uow) + mock_uow.__exit__ = unittest.mock.Mock(return_value=None) + return mock_uow + + +def test_create_product_should_store_in_repository(): + # Arrange + mock_unit_of_work = _create_mock_unit_of_work() command = create_product_command.CreateProductCommand( name="Test Product", @@ -93,12 +102,7 @@ def test_create_product_should_store_in_repository(): def test_update_product_should_only_update_specified_property(): # Arrange - mock_unit_of_work = unittest.mock.create_autospec( - spec=unit_of_work.UnitOfWork, instance=True - ) - mock_unit_of_work.products = unittest.mock.create_autospec( - spec=unit_of_work.ProductsRepository, instance=True - ) + mock_unit_of_work = _create_mock_unit_of_work() # Update only the description product_id = str(uuid.uuid4()) @@ -122,12 +126,7 @@ def test_update_product_should_only_update_specified_property(): def test_delete_product_should_delete_from_repository(): # Arrange - mock_unit_of_work = unittest.mock.create_autospec( - spec=unit_of_work.UnitOfWork, instance=True - ) - mock_unit_of_work.products = unittest.mock.create_autospec( - spec=unit_of_work.ProductsRepository, instance=True - ) + mock_unit_of_work = _create_mock_unit_of_work() product_id = str(uuid.uuid4()) command = delete_product_command.DeleteProductCommand(id=product_id) @@ -144,3 +143,35 @@ def test_delete_product_should_delete_from_repository(): ] assertpy.assert_that(deleted_product_id).is_equal_to(product_id) + + +def test_add_product_version_should_store_in_repository(): + # Arrange + mock_unit_of_work = _create_mock_unit_of_work() + mock_unit_of_work.product_versions = unittest.mock.create_autospec( + spec=unit_of_work.ProductVersionsRepository, instance=True + ) + mock_unit_of_work.products.get.return_value = unittest.mock.MagicMock() + + command = add_product_version_command.AddProductVersionCommand( + product_id="Test Product ID", + name="Test Product", + version="Test Version", + ) + + # Act + add_product_version_command_handler.handle_add_product_version_command( + command=command, unit_of_work=mock_unit_of_work + ) + + # Assert + mock_unit_of_work.products.get.assert_called_once_with( + product_id="Test Product ID" + ) + product_id = mock_unit_of_work.products.get.call_args.kwargs["product_id"] + version = mock_unit_of_work.product_versions.add.call_args.args[1] + mock_unit_of_work.commit.assert_called_once() + + assertpy.assert_that(product_id).is_equal_to("Test Product ID") + assertpy.assert_that(version.name).is_equal_to("Test Product") + assertpy.assert_that(version.version).is_equal_to("Test Version") diff --git a/app/entrypoints/api/handler.py b/app/entrypoints/api/handler.py index e43050f..061a5df 100644 --- a/app/entrypoints/api/handler.py +++ b/app/entrypoints/api/handler.py @@ -10,6 +10,7 @@ create_product_command_handler, delete_product_command_handler, update_product_command_handler, + add_product_version_command_handler, ) from app.domain.commands import ( get_product_command, @@ -17,6 +18,7 @@ create_product_command, delete_product_command, update_product_command, + add_product_version_command, ) from app.domain.exceptions.domain_exception import DomainException from app.entrypoints.api import config @@ -147,6 +149,30 @@ def delete_product( return response.dict() +@tracer.capture_method +@app.post("/products//versions") +@utils.parse_event(model=api_model.AddProductVersionRequest, app_context=app) +def add_product_version( + request: api_model.AddProductVersionRequest, id: str +) -> api_model.CreateProductResponse: + """Adds a version to a product.""" + + product = add_product_version_command_handler.handle_add_product_version_command( + command=add_product_version_command.AddProductVersionCommand( + product_id=id, + name=request.name, + version=request.version, + ), + unit_of_work=unit_of_work, + ) + + if not product: + raise DomainException(f"Could not locate product with id: {id}.") + + response = api_model.CreateProductResponse(id=id) + return response.dict() + + @tracer.capture_lambda_handler @logger.inject_lambda_context(log_event=True) @data_classes.event_source( diff --git a/app/entrypoints/api/model/api_model.py b/app/entrypoints/api/model/api_model.py index 0cdd68c..919875c 100644 --- a/app/entrypoints/api/model/api_model.py +++ b/app/entrypoints/api/model/api_model.py @@ -44,3 +44,7 @@ class Product(BaseModel): class ListProductsResponse(BaseModel): nextToken: Optional[Dict[str, Any]] = Field(title="LastEvaluatedKey token") products: List[Product] = Field(..., title="Products") + +class AddProductVersionRequest(BaseModel): + name: Optional[str] = Field(title="Name") + version: str = Field(..., title="Version") diff --git a/app/entrypoints/api/tests/test_api_handler.py b/app/entrypoints/api/tests/test_api_handler.py index 4c85e86..3b7be68 100644 --- a/app/entrypoints/api/tests/test_api_handler.py +++ b/app/entrypoints/api/tests/test_api_handler.py @@ -1,6 +1,7 @@ import json import unittest from dataclasses import dataclass +from unittest.mock import patch import assertpy import pytest @@ -10,6 +11,8 @@ create_product_command_handler, delete_product_command_handler, update_product_command_handler, + get_product_command_handler, + list_products_command_handler, ) from app.domain.ports import products_query_service from app.entrypoints.api import handler @@ -28,7 +31,8 @@ class LambdaContext: return LambdaContext() -def test_create_product(lambda_context): +@patch.object(create_product_command_handler, "handle_create_product_command") +def test_create_product(create_product_func_mock, lambda_context): # Arrange name = "TestName" description = "Test description" @@ -45,13 +49,6 @@ def test_create_product(lambda_context): } ) - create_product_func_mock = unittest.mock.create_autospec( - spec=create_product_command_handler.handle_create_product_command - ) - handler.create_product_command_handler.handle_create_product_command = ( - create_product_func_mock - ) - # Act handler.handler(minimal_event, lambda_context) @@ -62,7 +59,8 @@ def test_create_product(lambda_context): assertpy.assert_that(command.description).is_equal_to(description) -def test_update_product(lambda_context): +@patch.object(update_product_command_handler, "handle_update_product_command") +def test_update_product(update_product_func_mock, lambda_context): # Arrange id = "test-id" description = "Test description" @@ -79,13 +77,6 @@ def test_update_product(lambda_context): } ) - update_product_func_mock = unittest.mock.create_autospec( - spec=update_product_command_handler.handle_update_product_command - ) - handler.update_product_command_handler.handle_update_product_command = ( - update_product_func_mock - ) - # Act handler.handler(minimal_event, lambda_context) @@ -96,7 +87,8 @@ def test_update_product(lambda_context): assertpy.assert_that(command.description).is_equal_to(description) -def test_delete_product(lambda_context): +@patch.object(delete_product_command_handler, "handle_delete_product_command") +def test_delete_product(delete_product_func_mock, lambda_context): # Arrange id = "test-id" @@ -110,13 +102,6 @@ def test_delete_product(lambda_context): } ) - delete_product_func_mock = unittest.mock.create_autospec( - spec=delete_product_command_handler.handle_delete_product_command - ) - handler.delete_product_command_handler.handle_delete_product_command = ( - delete_product_func_mock - ) - # Act handler.handler(minimal_event, lambda_context) @@ -126,7 +111,8 @@ def test_delete_product(lambda_context): assertpy.assert_that(command.id).is_equal_to(id) -def test_get_product(lambda_context): +@patch.object(get_product_command_handler, "handle_get_product_command") +def test_get_product(get_product_func_mock, lambda_context): # Arrange id = "test-id" minimal_event = api_gateway_proxy_event.APIGatewayProxyEvent( @@ -139,21 +125,17 @@ def test_get_product(lambda_context): } ) - mock_query_service = unittest.mock.create_autospec( - spec=products_query_service.ProductsQueryService - ) - handler.products_query_service = mock_query_service - # Act handler.handler(minimal_event, lambda_context) # Assert - mock_query_service.get_product_by_id.assert_called_once() - got_product_id = mock_query_service.get_product_by_id.call_args.kwargs["product_id"] - assertpy.assert_that(got_product_id).is_equal_to(id) + get_product_func_mock.assert_called_once() + command = get_product_func_mock.call_args.kwargs["command"] + assertpy.assert_that(command.id).is_equal_to(id) -def test_list_products(lambda_context): +@patch.object(list_products_command_handler, "handle_list_products_command") +def test_list_products(list_products_func_mock, lambda_context): # Arrange page_size = 10 minimal_event = api_gateway_proxy_event.APIGatewayProxyEvent( @@ -167,15 +149,10 @@ def test_list_products(lambda_context): } ) - mock_query_service = unittest.mock.create_autospec( - spec=products_query_service.ProductsQueryService - ) - handler.products_query_service = mock_query_service - # Act handler.handler(minimal_event, lambda_context) # Assert - mock_query_service.list_products.assert_called_once() - got_page_size = mock_query_service.list_products.call_args.kwargs["page_size"] - assertpy.assert_that(got_page_size).is_equal_to(page_size) + list_products_func_mock.assert_called_once() + command = list_products_func_mock.call_args.kwargs["command"] + assertpy.assert_that(command.page_size).is_equal_to(page_size) diff --git a/infra/simple_crud_app_stack.py b/infra/simple_crud_app_stack.py index d5ad6ae..70cb431 100644 --- a/infra/simple_crud_app_stack.py +++ b/infra/simple_crud_app_stack.py @@ -77,11 +77,18 @@ def __init__( products_id.add_method( "DELETE", authorization_type=aws_apigateway.AuthorizationType.IAM ) + products_id_versions = products_id.add_resource("versions") + products_id_versions.add_method( + "POST", authorization_type=aws_apigateway.AuthorizationType.IAM + ) products.add_cors_preflight(allow_origins=["*"], allow_methods=["GET", "POST"]) products_id.add_cors_preflight( allow_origins=["*"], allow_methods=["GET", "PUT", "DELETE"] ) + products_id_versions.add_cors_preflight( + allow_origins=["*"], allow_methods=["POST"] + ) cdk_nag.NagSuppressions.add_resource_suppressions_by_path( stack=self, @@ -103,6 +110,16 @@ def __init__( ), ], ) + cdk_nag.NagSuppressions.add_resource_suppressions_by_path( + stack=self, + path='/SimpleCrudAppStack/SimpleCrudAppApi/SimpleCrudAppRestApi/Default/products/{id}/versions/OPTIONS/Resource', + suppressions=[ + cdk_nag.NagPackSuppression( + id="AwsSolutions-APIG4", + reason="OPTIONS methods have no authorization.", + ), + ], + ) cdk_nag.NagSuppressions.add_resource_suppressions( construct=self._api,