diff --git a/api/gen_ai/tests/test_math.py b/api/gen_ai/tests/test_math.py new file mode 100644 index 0000000000..619b5ab2ab --- /dev/null +++ b/api/gen_ai/tests/test_math.py @@ -0,0 +1,191 @@ +import math +import unittest +from datetime import datetime + +import pytest +from api.gen_ai.math import ( + Calculator, + DataProcessor, + StringManipulator, + add, + divide, + factorial, + fibonacci, + is_prime, + merge_dicts, + parse_date, + safe_list_access, + subtract, +) + + +class TestBasicArithmeticFunctions(unittest.TestCase): + def test_add(self): + self.assertEqual(add(2, 3), 5) + self.assertEqual(add(-1, 1), 0) + self.assertEqual(add(0, 0), 0) + self.assertEqual(add(2.5, 3.5), 6.0) + + def test_subtract(self): + self.assertEqual(subtract(5, 3), 2) + self.assertEqual(subtract(1, 1), 0) + self.assertEqual(subtract(0, 5), -5) + self.assertEqual(subtract(5.5, 2.5), 3.0) + + def test_divide(self): + self.assertEqual(divide(6, 3), 2) + self.assertEqual(divide(5, 2), 2.5) + self.assertEqual(divide(0, 5), 0) + self.assertEqual(divide(-6, 3), -2) + + def test_divide_by_zero(self): + with self.assertRaises(ValueError) as context: + divide(5, 0) + self.assertEqual(str(context.exception), "Division by zero is not allowed") + + +class TestFactorial(unittest.TestCase): + def test_factorial_positive(self): + self.assertEqual(factorial(0), 1) + self.assertEqual(factorial(1), 1) + self.assertEqual(factorial(5), 120) + self.assertEqual(factorial(10), 3628800) + + def test_factorial_negative(self): + with self.assertRaises(ValueError) as context: + factorial(-1) + self.assertEqual(str(context.exception), "Negative numbers do not have factorials") + + +class TestIsPrime(unittest.TestCase): + def test_is_prime(self): + self.assertFalse(is_prime(0)) + self.assertFalse(is_prime(1)) + self.assertTrue(is_prime(2)) + self.assertTrue(is_prime(3)) + self.assertFalse(is_prime(4)) + self.assertTrue(is_prime(5)) + self.assertFalse(is_prime(6)) + self.assertTrue(is_prime(7)) + self.assertTrue(is_prime(11)) + self.assertTrue(is_prime(13)) + self.assertTrue(is_prime(17)) + self.assertTrue(is_prime(19)) + self.assertTrue(is_prime(97)) + self.assertFalse(is_prime(100)) + + +class TestFibonacci(unittest.TestCase): + def test_fibonacci(self): + self.assertEqual(fibonacci(0), 0) + self.assertEqual(fibonacci(1), 1) + self.assertEqual(fibonacci(2), 1) + self.assertEqual(fibonacci(3), 2) + self.assertEqual(fibonacci(4), 3) + self.assertEqual(fibonacci(5), 5) + self.assertEqual(fibonacci(6), 8) + self.assertEqual(fibonacci(10), 55) + + def test_fibonacci_negative(self): + with self.assertRaises(ValueError) as context: + fibonacci(-1) + self.assertEqual(str(context.exception), "n must be a non-negative integer") + + +class TestCalculator(unittest.TestCase): + def setUp(self): + self.calc = Calculator() + + def test_add(self): + self.assertEqual(self.calc.add(2, 3), 5) + + def test_subtract(self): + self.assertEqual(self.calc.subtract(5, 3), 2) + + def test_multiply(self): + self.assertEqual(self.calc.multiply(2, 3), 6) + + def test_divide(self): + self.assertEqual(self.calc.divide(6, 3), 2) + self.assertEqual(self.calc.divide(5, 2), 2.5) + + def test_divide_by_zero(self): + with self.assertRaises(ValueError) as context: + self.calc.divide(5, 0) + self.assertEqual(str(context.exception), "Cannot divide by zero") + + def test_memory_operations(self): + self.calc.store(5) + self.assertEqual(self.calc.recall(), 5) + self.calc.store(10) + self.assertEqual(self.calc.recall(), 10) + + +class TestStringManipulator(unittest.TestCase): + def test_reverse_string(self): + self.assertEqual(StringManipulator.reverse_string("hello"), "olleh") + self.assertEqual(StringManipulator.reverse_string(""), "") + self.assertEqual(StringManipulator.reverse_string("a"), "a") + self.assertEqual(StringManipulator.reverse_string("12345"), "54321") + + def test_is_palindrome(self): + self.assertTrue(StringManipulator.is_palindrome("racecar")) + self.assertTrue(StringManipulator.is_palindrome("A man, a plan, a canal: Panama")) + self.assertTrue(StringManipulator.is_palindrome("")) + self.assertTrue(StringManipulator.is_palindrome("a")) + self.assertFalse(StringManipulator.is_palindrome("hello")) + self.assertFalse(StringManipulator.is_palindrome("world")) + + +class TestDataProcessor(unittest.TestCase): + def test_get_mean(self): + dp = DataProcessor([1, 2, 3, 4, 5]) + self.assertEqual(dp.get_mean(), 3.0) + + def test_get_variance(self): + dp = DataProcessor([1, 2, 3, 4, 5]) + self.assertEqual(dp.get_variance(), 2.5) + + def test_normalize(self): + dp = DataProcessor([1, 2, 3, 4, 5]) + normalized = dp.normalize() + expected = [ + -1.264911064067352, + -0.6324555320336759, + 0.0, + 0.6324555320336759, + 1.264911064067352, + ] + for i in range(len(normalized)): + self.assertAlmostEqual(normalized[i], expected[i]) + + def test_empty_data(self): + with self.assertRaises(ValueError) as context: + dp = DataProcessor([]) + self.assertEqual(str(context.exception), "Data list cannot be empty") + + def test_variance_single_value(self): + dp = DataProcessor([5]) + with self.assertRaises(ValueError) as context: + dp.get_variance() + self.assertEqual(str(context.exception), "At least two data points are required to compute variance") + + +class TestUtilityFunctions(unittest.TestCase): + def test_parse_date(self): + self.assertEqual(parse_date("2023-01-01"), datetime(2023, 1, 1)) + with self.assertRaises(ValueError): + parse_date("01/01/2023") + + def test_safe_list_access(self): + test_list = [1, 2, 3] + self.assertEqual(safe_list_access(test_list, 1), 2) + self.assertEqual(safe_list_access(test_list, 5), None) + self.assertEqual(safe_list_access(test_list, 5, "default"), "default") + + def test_merge_dicts(self): + dict1 = {"a": 1, "b": {"c": 2, "d": 3}} + dict2 = {"b": {"e": 4}, "f": 5} + result = merge_dicts(dict1, dict2) + expected = {"a": 1, "b": {"c": 2, "d": 3, "e": 4}, "f": 5} + self.assertEqual(result, expected) \ No newline at end of file diff --git a/api/gen_ai/tests/test_webhook_handlers.py b/api/gen_ai/tests/test_webhook_handlers.py new file mode 100644 index 0000000000..32d3a08603 --- /dev/null +++ b/api/gen_ai/tests/test_webhook_handlers.py @@ -0,0 +1,228 @@ +import unittest +from datetime import datetime +from unittest.mock import Mock, patch + +import stripe +from django.test import TestCase + +from billing.views import StripeWebhookHandler +from shared.django_apps.core.tests.factories import OwnerFactory + + +class TestStripeWebhookHandlerExtensions(TestCase): + def setUp(self): + self.owner = OwnerFactory( + stripe_customer_id="cus_123", + stripe_subscription_id="sub_123", + ) + self.handler = StripeWebhookHandler() + + @patch("logging.Logger.error") + def test_payment_intent_payment_failed(self, log_error_mock): + payment_intent = Mock() + payment_intent.customer = self.owner.stripe_customer_id + payment_intent.id = "pi_123" + + self.handler.payment_intent_payment_failed(payment_intent) + + # Verify owner marked as delinquent + self.owner.refresh_from_db() + self.assertTrue(self.owner.delinquent) + + # Verify error was logged + log_error_mock.assert_called_once() + + @patch("services.task.TaskService.send_email") + def test_payment_intent_payment_failed_sends_emails(self, mock_send_email): + # Create admin users + admin1 = OwnerFactory(email="admin1@example.com") + admin2 = OwnerFactory(email="admin2@example.com") + self.owner.admins = [admin1.ownerid, admin2.ownerid] + self.owner.email = "owner@example.com" + self.owner.save() + + payment_intent = Mock() + payment_intent.customer = self.owner.stripe_customer_id + payment_intent.id = "pi_123" + payment_intent.amount_received = 24000 + + self.handler.payment_intent_payment_failed(payment_intent) + + # Verify emails were sent to owner and admins + self.assertEqual(mock_send_email.call_count, 3) + + @patch("logging.Logger.info") + def test_charge_refunded(self, log_info_mock): + charge = Mock() + charge.id = "ch_123" + charge.customer = self.owner.stripe_customer_id + charge.amount_refunded = 5000 + + self.handler.charge_refunded(charge) + + # Verify info was logged + log_info_mock.assert_called_once() + + @patch("services.task.TaskService.send_email") + def test_charge_refunded_sends_emails(self, mock_send_email): + # Create admin users + admin1 = OwnerFactory(email="admin1@example.com") + admin2 = OwnerFactory(email="admin2@example.com") + self.owner.admins = [admin1.ownerid, admin2.ownerid] + self.owner.email = "owner@example.com" + self.owner.save() + + charge = Mock() + charge.id = "ch_123" + charge.customer = self.owner.stripe_customer_id + charge.amount_refunded = 5000 + + self.handler.charge_refunded(charge) + + # Verify emails were sent to owner and admins + self.assertEqual(mock_send_email.call_count, 3) + + @patch("logging.Logger.warning") + @patch("stripe.Charge.retrieve") + def test_dispute_created(self, retrieve_charge_mock, log_warning_mock): + dispute = Mock() + dispute.id = "dp_123" + dispute.charge = "ch_123" + dispute.amount = 5000 + dispute.status = "needs_response" + dispute.reason = "fraudulent" + + charge = Mock() + charge.customer = self.owner.stripe_customer_id + retrieve_charge_mock.return_value = charge + + self.handler.dispute_created(dispute) + + # Verify warning was logged + log_warning_mock.assert_called_once() + + # Verify charge was retrieved + retrieve_charge_mock.assert_called_once_with(dispute.charge) + + @patch("services.task.TaskService.send_email") + @patch("stripe.Charge.retrieve") + def test_dispute_created_sends_emails(self, retrieve_charge_mock, mock_send_email): + # Create admin users + admin1 = OwnerFactory(email="admin1@example.com") + admin2 = OwnerFactory(email="admin2@example.com") + self.owner.admins = [admin1.ownerid, admin2.ownerid] + self.owner.email = "owner@example.com" + self.owner.save() + + dispute = Mock() + dispute.id = "dp_123" + dispute.charge = "ch_123" + dispute.amount = 5000 + dispute.status = "needs_response" + dispute.reason = "fraudulent" + + charge = Mock() + charge.customer = self.owner.stripe_customer_id + retrieve_charge_mock.return_value = charge + + self.handler.dispute_created(dispute) + + # Verify emails were sent to owner and admins + self.assertEqual(mock_send_email.call_count, 3) + + @patch("logging.Logger.info") + def test_invoice_updated(self, log_info_mock): + invoice = Mock() + invoice.id = "in_123" + invoice.customer = self.owner.stripe_customer_id + invoice.status = "open" + invoice.total = 5000 + + self.handler.invoice_updated(invoice) + + # Verify info was logged + log_info_mock.assert_called_once() + + @patch("services.task.TaskService.send_email") + def test_invoice_updated_sends_emails_when_open(self, mock_send_email): + # Create admin users + admin1 = OwnerFactory(email="admin1@example.com") + admin2 = OwnerFactory(email="admin2@example.com") + self.owner.admins = [admin1.ownerid, admin2.ownerid] + self.owner.email = "owner@example.com" + self.owner.save() + + invoice = Mock() + invoice.id = "in_123" + invoice.customer = self.owner.stripe_customer_id + invoice.status = "open" + invoice.total = 5000 + + self.handler.invoice_updated(invoice) + + # Verify emails were sent to owner and admins + self.assertEqual(mock_send_email.call_count, 3) + + @patch("services.task.TaskService.send_email") + def test_invoice_updated_doesnt_send_emails_when_not_open(self, mock_send_email): + # Create admin users + admin1 = OwnerFactory(email="admin1@example.com") + admin2 = OwnerFactory(email="admin2@example.com") + self.owner.admins = [admin1.ownerid, admin2.ownerid] + self.owner.email = "owner@example.com" + self.owner.save() + + invoice = Mock() + invoice.id = "in_123" + invoice.customer = self.owner.stripe_customer_id + invoice.status = "paid" + invoice.total = 5000 + + self.handler.invoice_updated(invoice) + + # Verify no emails were sent since status isn't open + mock_send_email.assert_not_called() + + @patch("logging.Logger.info") + def test_account_updated(self, log_info_mock): + account = Mock() + account.id = "acct_123" + account.email = "test@example.com" + + self.handler.account_updated(account) + + # Verify info was logged + log_info_mock.assert_called_once() + + @patch("logging.Logger.info") + def test_default_event_handler(self, log_info_mock): + event_object = {"id": "evt_123", "type": "unknown.event"} + + self.handler.default_event_handler(event_object) + + # Verify info was logged + log_info_mock.assert_called_once() + + @patch("time.sleep") + @patch("logging.Logger.info") + def test_simulate_delay(self, log_info_mock, sleep_mock): + seconds = 5 + + self.handler.simulate_delay(seconds) + + # Verify delay was simulated + sleep_mock.assert_called_once_with(seconds) + log_info_mock.assert_called_once() + + @patch("shared.plan.service.PlanService") + @patch("logging.Logger.info") + def test_revalidate_subscription(self, log_info_mock, plan_service_mock): + subscription = Mock() + subscription.id = self.owner.stripe_subscription_id + subscription.customer = self.owner.stripe_customer_id + subscription.quantity = 10 + + self.handler.revalidate_subscription(subscription) + + # Verify subscription was revalidated + log_info_mock.assert_called_once() \ No newline at end of file diff --git a/billing/tests/test_views.py b/billing/tests/test_views.py index 8218f9bfc8..208fed85ea 100644 --- a/billing/tests/test_views.py +++ b/billing/tests/test_views.py @@ -907,7 +907,75 @@ def test_customer_subscription_updated_does_not_change_subscription_if_not_paid_ ) @patch("billing.views.StripeWebhookHandler._has_unverified_initial_payment_method") - @patch("logging.Logger.info") + def test_customer_subscription_updated_payment_failed( + self, has_unverified_initial_payment_method_mock + ): + has_unverified_initial_payment_method_mock.return_value = False + self.owner.delinquent = False + self.owner.save() + + self._send_event( + payload={ + "type": "customer.subscription.updated", + "data": { + "object": { + "id": self.owner.stripe_subscription_id, + "customer": self.owner.stripe_customer_id, + "plan": {"id": "?"}, + "metadata": {"obo_organization": self.owner.ownerid}, + "quantity": 20, + "status": "active", + "schedule": None, + "default_payment_method": "pm_1LhiRsGlVGuVgOrkQguJXdeV", + "pending_update": { + "expires_at": 1571194285, + "subscription_items": [ + { + "id": "si_09IkI4u3ZypJUk5onGUZpe8O", + "price": "price_CBb6IXqvTLXp3f", + } + ], + }, + } + }, + } + ) + + self.owner.refresh_from_db() + assert self.owner.delinquent == True + + @patch("billing.views.StripeWebhookHandler.default_event_handler") + def test_post_with_unhandled_event(self, default_event_handler_mock): + """Test that unhandled event types are routed to default_event_handler""" + self._send_event( + payload={ + "type": "unknown.event.type", + "data": {"object": {"id": "test_id"}}, + } + ) + default_event_handler_mock.assert_called_once() + + @patch("logging.Logger.error") + def test_post_error_handling(self, log_error_mock): + """Test error handling in the post method when an event handler raises an exception""" + with patch.object(StripeWebhookHandler, "customer_created", side_effect=Exception("Test error")): + self._send_event( + payload={ + "type": "customer.created", + "data": {"object": {"id": "FOEKDCDEQ", "email": "test@email.com"}}, + } + ) + + # Verify that the error was logged + log_error_mock.assert_called_once_with( + "Error handling event", + extra={"error": "Test error"} + ) + + # Verify response was still successful despite the error + # (This assumes the test_send_event returns the response) + + @patch("billing.views.StripeWebhookHandler._has_unverified_initial_payment_method") @patch("services.billing.stripe.PaymentMethod.attach") @patch("services.billing.stripe.Customer.modify") def test_customer_subscription_updated_does_not_change_subscription_if_there_is_a_schedule(