diff --git a/dojo/announcement/os_message.py b/dojo/announcement/os_message.py new file mode 100644 index 00000000000..dfdb9288710 --- /dev/null +++ b/dojo/announcement/os_message.py @@ -0,0 +1,119 @@ +import logging + +import bleach +import markdown +import requests +from django.core.cache import cache + +logger = logging.getLogger(__name__) + +BUCKET_URL = "https://storage.googleapis.com/defectdojo-os-messages-prod/open_source_message.md" +CACHE_SECONDS = 3600 +HTTP_TIMEOUT_SECONDS = 2 +CACHE_KEY = "os_message:v1" + +INLINE_TAGS = ["strong", "em", "a"] +INLINE_ATTRS = {"a": ["href", "title"]} + +# Keep BLOCK_TAGS / BLOCK_ATTRS in sync with the DaaS publisher's +# MARKDOWNIFY["default"]["WHITELIST_TAGS"] / WHITELIST_ATTRS so previews +# on DaaS and rendering in OSS stay byte-identical. +BLOCK_TAGS = [ + "p", "ul", "ol", "li", "a", "strong", "em", "code", "pre", + "blockquote", "h2", "h3", "h4", "hr", "br", "b", "i", + "abbr", "acronym", +] +BLOCK_ATTRS = { + "a": ["href", "title"], + "abbr": ["title"], + "acronym": ["title"], +} + +_MISS = object() + + +def fetch_os_message(): + cached = cache.get(CACHE_KEY, default=_MISS) + if cached is not _MISS: + return cached + + try: + response = requests.get(BUCKET_URL, timeout=HTTP_TIMEOUT_SECONDS) + except Exception: + logger.debug("os_message: fetch failed", exc_info=True) + cache.set(CACHE_KEY, None, CACHE_SECONDS) + return None + + if response.status_code != 200 or not response.text.strip(): + cache.set(CACHE_KEY, None, CACHE_SECONDS) + return None + + cache.set(CACHE_KEY, response.text, CACHE_SECONDS) + return response.text + + +def _strip_outer_p(html): + stripped = html.strip() + if stripped.startswith("

") and stripped.endswith("

"): + return stripped[3:-4] + return stripped + + +def parse_os_message(text): + lines = text.splitlines() + + headline_source = None + body_start = None + for index, line in enumerate(lines): + if line.startswith("# "): + headline_source = line[2:].strip() + body_start = index + 1 + break + + if not headline_source: + return None + + headline_source = headline_source[:100] + headline_rendered = markdown.markdown(headline_source) + headline_cleaned = bleach.clean( + headline_rendered, + tags=INLINE_TAGS, + attributes=INLINE_ATTRS, + strip=True, + ) + headline_html = _strip_outer_p(headline_cleaned) + + expanded_html = None + expanded_marker = "## Expanded Message" + expanded_body_lines = None + for offset, line in enumerate(lines[body_start:], start=body_start): + if line.strip() == expanded_marker: + expanded_body_lines = lines[offset + 1:] + break + + if expanded_body_lines is not None: + expanded_source = "\n".join(expanded_body_lines).strip() + if expanded_source: + expanded_rendered = markdown.markdown( + expanded_source, + extensions=["extra", "fenced_code", "nl2br"], + ) + expanded_html = bleach.clean( + expanded_rendered, + tags=BLOCK_TAGS, + attributes=BLOCK_ATTRS, + strip=True, + ) + + return {"message": headline_html, "expanded_html": expanded_html} + + +def get_os_banner(): + try: + text = fetch_os_message() + if not text: + return None + return parse_os_message(text) + except Exception: + logger.debug("os_message: get_os_banner failed", exc_info=True) + return None diff --git a/dojo/announcement/signals.py b/dojo/announcement/signals.py index dedd3444654..c74fd0e5d50 100644 --- a/dojo/announcement/signals.py +++ b/dojo/announcement/signals.py @@ -1,4 +1,3 @@ -from django.conf import settings from django.db.models.signals import post_save from django.dispatch import receiver @@ -7,22 +6,11 @@ @receiver(post_save, sender=Dojo_User) def add_announcement_to_new_user(sender, instance, **kwargs): - announcements = Announcement.objects.all() - if announcements.count() > 0: - dojo_user = Dojo_User.objects.get(id=instance.id) - announcement = announcements.first() - cloud_announcement = ( - "DefectDojo Pro Cloud and On-Premise Subscriptions Now Available!" - in announcement.message + announcement = Announcement.objects.first() + if announcement is not None: + UserAnnouncement.objects.get_or_create( + user=instance, announcement=announcement, ) - if not cloud_announcement or settings.CREATE_CLOUD_BANNER: - user_announcements = UserAnnouncement.objects.filter( - user=dojo_user, announcement=announcement, - ) - if user_announcements.count() == 0: - UserAnnouncement.objects.get_or_create( - user=dojo_user, announcement=announcement, - ) @receiver(post_save, sender=Announcement) diff --git a/dojo/context_processors.py b/dojo/context_processors.py index cc53af0f1e0..792e1eb6b42 100644 --- a/dojo/context_processors.py +++ b/dojo/context_processors.py @@ -5,13 +5,14 @@ from django.conf import settings from django.contrib import messages +from dojo.announcement.os_message import get_os_banner from dojo.labels import get_labels from dojo.models import Alerts, System_Settings, UserAnnouncement def globalize_vars(request): # return the value you want as a dictionnary. you may add multiple values in there. - return { + context = { "SHOW_LOGIN_FORM": settings.SHOW_LOGIN_FORM, "FORGOT_PASSWORD": settings.FORGOT_PASSWORD, "FORGOT_USERNAME": settings.FORGOT_USERNAME, @@ -35,11 +36,32 @@ def globalize_vars(request): "DOCUMENTATION_URL": settings.DOCUMENTATION_URL, "API_TOKENS_ENABLED": settings.API_TOKENS_ENABLED, "API_TOKEN_AUTH_ENDPOINT_ENABLED": settings.API_TOKEN_AUTH_ENDPOINT_ENABLED, - "CREATE_CLOUD_BANNER": settings.CREATE_CLOUD_BANNER, + "SHOW_PLG_LINK": True, # V3 Feature Flags "V3_FEATURE_LOCATIONS": settings.V3_FEATURE_LOCATIONS, } + additional_banners = [] + + if (os_banner := get_os_banner()) is not None: + additional_banners.append({ + "source": "os", + "message": os_banner["message"], + "style": "info", + "url": "", + "link_text": "", + "expanded_html": os_banner["expanded_html"], + }) + + if hasattr(request, "session"): + for banner in request.session.pop("_product_banners", []): + additional_banners.append(banner) + + if additional_banners: + context["additional_banners"] = additional_banners + + return context + def bind_system_settings(request): """Load system settings and display warning if there's a database error.""" diff --git a/dojo/management/commands/complete_initialization.py b/dojo/management/commands/complete_initialization.py index 556c77867fb..9ae58926324 100644 --- a/dojo/management/commands/complete_initialization.py +++ b/dojo/management/commands/complete_initialization.py @@ -14,7 +14,6 @@ from django.db.utils import ProgrammingError from dojo.auditlog import configure_pghistory_triggers -from dojo.models import Announcement, Dojo_User, UserAnnouncement class Command(BaseCommand): @@ -38,13 +37,11 @@ def handle(self, *args: Any, **options: Any) -> None: if self.admin_user_exists(): self.stdout.write("Admin user already exists; skipping first-boot setup") - self.create_announcement_banner() self.initialize_data() return self.ensure_admin_secrets() self.first_boot_setup() - self.create_announcement_banner() self.initialize_data() # ------------------------------------------------------------------ @@ -58,29 +55,6 @@ def initialize_data(self) -> None: self.stdout.write("Initializing non-standard permissions") call_command("initialize_permissions") - def create_announcement_banner(self) -> None: - if os.getenv("DD_CREATE_CLOUD_BANNER"): - return - - self.stdout.write("Creating announcement banner") - - announcement, _ = Announcement.objects.get_or_create(id=1) - announcement.message = ( - '' - "DefectDojo Pro Cloud and On-Premise Subscriptions Now Available! " - "Create an account to try Pro for free!" - "" - ) - announcement.dismissable = True - announcement.save() - - for user in Dojo_User.objects.all(): - UserAnnouncement.objects.get_or_create( - user=user, - announcement=announcement, - ) - # ------------------------------------------------------------------ # Auditlog consistency # ------------------------------------------------------------------ diff --git a/dojo/product_announcements.py b/dojo/product_announcements.py index 8510b42a0f8..90280885007 100644 --- a/dojo/product_announcements.py +++ b/dojo/product_announcements.py @@ -1,8 +1,6 @@ import logging -from django.conf import settings -from django.contrib import messages from django.http import HttpRequest, HttpResponse from django.utils.safestring import mark_safe from django.utils.translation import gettext_lazy as _ @@ -30,12 +28,8 @@ def __init__( response_data: dict | None = None, **kwargs: dict, ): - """Skip all this if the CREATE_CLOUD_BANNER is not set""" - if not settings.CREATE_CLOUD_BANNER: - return - # Fill in the vars if the were supplied correctly if request is not None and isinstance(request, HttpRequest): - self._add_django_message( + self._add_session_banner( request=request, message=mark_safe(f"{self.base_message} {self.ui_outreach}"), ) @@ -51,18 +45,21 @@ def __init__( msg = "At least one of request, response, or response_data must be supplied" raise ValueError(msg) - def _add_django_message(self, request: HttpRequest, message: str): - """Add a message to the UI""" + def _add_session_banner(self, request: HttpRequest, message: str): + """Store a banner in the session for rendering via additional_banners.""" try: - messages.add_message( - request=request, - level=messages.INFO, - message=_(message), - extra_tags="alert-info", - ) + banners = request.session.get("_product_banners", []) + banners.append({ + "source": "product_announcement", + "message": str(_(message)), + "style": "info", + "url": "", + "link_text": "", + "expanded_html": None, + }) + request.session["_product_banners"] = banners except Exception: - # make sure we catch any exceptions that might happen: https://github.com/DefectDojo/django-DefectDojo/issues/14041 - logger.exception(f"Error adding message to Django: {message}") + logger.exception(f"Error storing product announcement banner: {message}") def _add_api_response_key(self, message: str, data: dict) -> dict: """Update the response data in place""" diff --git a/dojo/settings/settings.dist.py b/dojo/settings/settings.dist.py index 429b5646b90..7dd95aae1f4 100644 --- a/dojo/settings/settings.dist.py +++ b/dojo/settings/settings.dist.py @@ -356,8 +356,6 @@ DD_HASHCODE_FIELDS_PER_SCANNER=(str, ""), # Set deduplication algorithms per parser, via en env variable that contains a JSON string DD_DEDUPLICATION_ALGORITHM_PER_PARSER=(str, ""), - # Dictates whether cloud banner is created or not - DD_CREATE_CLOUD_BANNER=(bool, True), # With this setting turned on, Dojo maintains an audit log of changes made to entities (Findings, Tests, Engagements, Products, ...) # If you run big import you may want to disable this because there's a performance hit during (re-)imports. DD_ENABLE_AUDITLOG=(bool, True), @@ -1339,13 +1337,6 @@ def saml2_attrib_map_format(din): "expires": int(60 * 1 * 1.2), # If a task is not executed within 72 seconds, it should be dropped from the queue. Two more tasks should be scheduled in the meantime. }, }, - "trigger_evaluate_pro_proposition": { - "task": "dojo.tasks.evaluate_pro_proposition", - "schedule": timedelta(hours=8), - "options": { - "expires": int(60 * 60 * 8 * 1.2), # If a task is not executed within 9.6 hours, it should be dropped from the queue. Two more tasks should be scheduled in the meantime. - }, - }, "clear_sessions": { "task": "dojo.tasks.clear_sessions", "schedule": crontab(hour=0, minute=0, day_of_week=0), @@ -2082,9 +2073,6 @@ def saml2_attrib_map_format(din): AUDITLOG_DISABLE_ON_RAW_SAVE = False # You can set extra Jira headers by suppling a dictionary in header: value format (pass as env var like "headr_name=value,another_header=anohter_value") ADDITIONAL_HEADERS = env("DD_ADDITIONAL_HEADERS") -# Dictates whether cloud banner is created or not -CREATE_CLOUD_BANNER = env("DD_CREATE_CLOUD_BANNER") - # ------------------------------------------------------------------------------ # Auditlog # ------------------------------------------------------------------------------ diff --git a/dojo/static/dojo/css/dojo.css b/dojo/static/dojo/css/dojo.css index deba72474c9..fea91da6c94 100644 --- a/dojo/static/dojo/css/dojo.css +++ b/dojo/static/dojo/css/dojo.css @@ -1124,6 +1124,41 @@ div.custom-search-form { .announcement-banner { margin: 0px -15px; border-radius: 0px 0px 4px 4px; + color: #000; +} + +.announcement-banner a { + color: #0645ad; + text-decoration: underline; +} + +.announcement-banner strong, +.announcement-banner b { + color: #222; +} + +.banner-toggle { + background: transparent; + border: 0; + padding: 0; + margin-left: 6px; + color: inherit; + cursor: pointer; + line-height: 1; +} + +.banner-toggle:focus, +.banner-toggle:active { + outline: none; + box-shadow: none; +} + +.banner-toggle:not(.collapsed) .fa-caret-down { + transform: rotate(180deg); +} + +.banner-expanded { + margin-top: 8px; } @media (min-width: 795px) { diff --git a/dojo/tasks.py b/dojo/tasks.py index 5a494072a8b..d1f90275dc2 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -16,9 +16,8 @@ from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.helper import fix_loop_duplicates -from dojo.location.models import Location from dojo.management.commands.jira_status_reconciliation import jira_status_reconciliation -from dojo.models import Alerts, Announcement, Endpoint, Engagement, Finding, Product, System_Settings, User +from dojo.models import Alerts, Engagement, Finding, Product, System_Settings, User from dojo.notifications.helper import create_notification from dojo.utils import calculate_grade, sla_compute_and_notify @@ -218,37 +217,6 @@ def fix_loop_duplicates_task(*args, **kwargs): return fix_loop_duplicates() -@app.task -def evaluate_pro_proposition(*args, **kwargs): - # Ensure we should be doing this - if not settings.CREATE_CLOUD_BANNER: - return - # Get the announcement object - announcement = Announcement.objects.get_or_create(id=1)[0] - # Quick check for a user has modified the current banner - if not, exit early as we dont want to stomp - if not any( - entry in announcement.message - for entry in [ - "", - "DefectDojo Pro Cloud and On-Premise Subscriptions Now Available!", - "Findings/Endpoints in their systems", - ] - ): - return - # Count the objects the determine if the banner should be updated - if settings.V3_FEATURE_LOCATIONS: - object_count = Finding.objects.count() + Location.objects.count() - else: - # TODO: Delete this after the move to Locations - object_count = Finding.objects.count() + Endpoint.objects.count() - # Unless the count is greater than 100k, exit early - if object_count < 100000: - return - # Update the announcement - announcement.message = f'Only professionals have {object_count:,} Findings and Endpoints in their systems... Get DefectDojo Pro today!' - announcement.save() - - @app.task def clear_sessions(*args, **kwargs): call_command("clearsessions") diff --git a/dojo/templates/base.html b/dojo/templates/base.html index 54d6e0e58af..b8f5489f1d7 100644 --- a/dojo/templates/base.html +++ b/dojo/templates/base.html @@ -199,7 +199,7 @@ {% endif %} - {% if CREATE_CLOUD_BANNER %} + {% if SHOW_PLG_LINK %}
  • @@ -671,8 +671,21 @@ {% endif %} {% for banner in additional_banners %} -
  • Major feature A
  • ", result["expanded_html"]) + self.assertIn("
  • Major feature B
  • ", result["expanded_html"]) + + def test_missing_headline_returns_none(self): + text = "No headline here\n## Expanded Message\nbody\n" + self.assertIsNone(os_message.parse_os_message(text)) + + def test_headline_inline_markdown(self): + text = "# Read the **release notes** at [link](https://example.com)\n" + result = os_message.parse_os_message(text) + self.assertIn("release notes", result["message"]) + self.assertIn('link', result["message"]) + self.assertIsNone(result["expanded_html"]) + + def test_headline_strips_disallowed_html(self): + text = "# Headline tail\n" + result = os_message.parse_os_message(text) + self.assertNotIn("", result["message"]) + self.assertIn("Headline", result["message"]) + + def test_missing_expanded_section(self): + text = "# Just a headline\n" + result = os_message.parse_os_message(text) + self.assertEqual(result["message"], "Just a headline") + self.assertIsNone(result["expanded_html"]) + + def test_expanded_with_fenced_code(self): + text = ( + "# Headline\n" + "## Expanded Message\n" + "```python\n" + "print('hi')\n" + "```\n" + ) + result = os_message.parse_os_message(text) + self.assertIn("
    ", result["expanded_html"])
    +        self.assertIn("", result["expanded_html"])
    +        self.assertIn("print('hi')", result["expanded_html"])
    +
    +    def test_expanded_strips_script_tag(self):
    +        text = (
    +            "# Headline\n"
    +            "## Expanded Message\n"
    +            "\n"
    +            "Body paragraph\n"
    +        )
    +        result = os_message.parse_os_message(text)
    +        self.assertNotIn("", result["expanded_html"])
    +        self.assertIn("Body paragraph", result["expanded_html"])
    +
    +    def test_headline_outer_p_is_stripped(self):
    +        text = "# Plain headline\n"
    +        result = os_message.parse_os_message(text)
    +        self.assertFalse(result["message"].startswith("

    ")) + self.assertFalse(result["message"].endswith("

    ")) + + def test_headline_truncated_to_100_chars(self): + long_headline = "x" * 200 + text = f"# {long_headline}\n" + result = os_message.parse_os_message(text) + self.assertLessEqual(len(result["message"]), 100) + + +@override_settings(CACHES={"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}}) +class TestFetchOsMessage(SimpleTestCase): + + def setUp(self): + cache.clear() + + def test_200_with_body_caches_body(self): + body = "# headline\n" + with patch("dojo.announcement.os_message.requests.get", return_value=_Resp(200, body)) as mock_get: + result = os_message.fetch_os_message() + self.assertEqual(result, body) + self.assertEqual(cache.get(os_message.CACHE_KEY), body) + mock_get.assert_called_once() + + def test_404_caches_none(self): + with patch("dojo.announcement.os_message.requests.get", return_value=_Resp(404, "not found")): + result = os_message.fetch_os_message() + self.assertIsNone(result) + self.assertIsNone(cache.get(os_message.CACHE_KEY, default="sentinel")) + + def test_timeout_caches_none(self): + with patch("dojo.announcement.os_message.requests.get", side_effect=requests.exceptions.Timeout): + result = os_message.fetch_os_message() + self.assertIsNone(result) + self.assertIsNone(cache.get(os_message.CACHE_KEY, default="sentinel")) + + def test_connection_error_caches_none(self): + with patch("dojo.announcement.os_message.requests.get", side_effect=requests.exceptions.ConnectionError): + result = os_message.fetch_os_message() + self.assertIsNone(result) + self.assertIsNone(cache.get(os_message.CACHE_KEY, default="sentinel")) + + def test_empty_body_caches_none(self): + with patch("dojo.announcement.os_message.requests.get", return_value=_Resp(200, " \n\n")): + result = os_message.fetch_os_message() + self.assertIsNone(result) + self.assertIsNone(cache.get(os_message.CACHE_KEY, default="sentinel")) + + def test_second_call_hits_cache(self): + with patch("dojo.announcement.os_message.requests.get", return_value=_Resp(200, "# h\n")) as mock_get: + os_message.fetch_os_message() + os_message.fetch_os_message() + self.assertEqual(mock_get.call_count, 1) + + def test_second_call_after_failure_also_hits_cache(self): + with patch("dojo.announcement.os_message.requests.get", side_effect=requests.exceptions.Timeout) as mock_get: + os_message.fetch_os_message() + os_message.fetch_os_message() + self.assertEqual(mock_get.call_count, 1) + + +@override_settings(CACHES={"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}}) +class TestGetOsBanner(SimpleTestCase): + + def setUp(self): + cache.clear() + + def test_returns_none_when_fetch_returns_none(self): + with patch("dojo.announcement.os_message.fetch_os_message", return_value=None): + self.assertIsNone(os_message.get_os_banner()) + + def test_swallows_parse_exception(self): + with patch("dojo.announcement.os_message.fetch_os_message", return_value="# ok\n"), \ + patch("dojo.announcement.os_message.parse_os_message", side_effect=RuntimeError("boom")): + self.assertIsNone(os_message.get_os_banner()) + + +@override_settings(CACHES={"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}}) +class TestGlobalizeVarsOsBanner(SimpleTestCase): + + def setUp(self): + cache.clear() + self.request = RequestFactory().get("/") + + def test_additional_banners_populated_when_banner_present(self): + banner = {"message": "Hi", "expanded_html": "

    body

    "} + with patch.object(context_processors, "get_os_banner", return_value=banner): + result = context_processors.globalize_vars(self.request) + self.assertIn("additional_banners", result) + entry = result["additional_banners"][0] + self.assertEqual(entry["source"], "os") + self.assertEqual(entry["message"], "Hi") + self.assertEqual(entry["expanded_html"], "

    body

    ") + self.assertEqual(entry["style"], "info") + self.assertEqual(entry["url"], "") + self.assertEqual(entry["link_text"], "") + + def test_additional_banners_absent_when_no_banner(self): + with patch.object(context_processors, "get_os_banner", return_value=None): + result = context_processors.globalize_vars(self.request) + self.assertNotIn("additional_banners", result) + + def test_show_plg_link_is_true_by_default(self): + with patch.object(context_processors, "get_os_banner", return_value=None): + result = context_processors.globalize_vars(self.request) + self.assertTrue(result["SHOW_PLG_LINK"]) + + def test_create_cloud_banner_not_in_context(self): + with patch.object(context_processors, "get_os_banner", return_value=None): + result = context_processors.globalize_vars(self.request) + self.assertNotIn("CREATE_CLOUD_BANNER", result) + + def test_template_renders_bleached_message(self): + banner = {"message": "Hi", "expanded_html": None} + with patch.object(context_processors, "get_os_banner", return_value=banner): + ctx = context_processors.globalize_vars(self.request) + rendered = Template( + "{% for b in additional_banners %}{{ b.message|safe }}{% endfor %}", + ).render(Context(ctx)) + self.assertIn("Hi", rendered) + + def test_session_product_banners_merged_into_additional_banners(self): + session_banner = { + "source": "product_announcement", + "message": "Pro has async imports!", + "style": "info", + "url": "", + "link_text": "", + "expanded_html": None, + } + self.request.session = {"_product_banners": [session_banner]} + with patch.object(context_processors, "get_os_banner", return_value=None): + result = context_processors.globalize_vars(self.request) + self.assertIn("additional_banners", result) + self.assertEqual(len(result["additional_banners"]), 1) + self.assertEqual(result["additional_banners"][0]["source"], "product_announcement") + self.assertEqual(self.request.session.get("_product_banners"), None) + + def test_os_and_session_banners_combined(self): + os_banner = {"message": "OS msg", "expanded_html": None} + session_banner = { + "source": "product_announcement", + "message": "Pro msg", + "style": "info", + "url": "", + "link_text": "", + "expanded_html": None, + } + self.request.session = {"_product_banners": [session_banner]} + with patch.object(context_processors, "get_os_banner", return_value=os_banner): + result = context_processors.globalize_vars(self.request) + self.assertEqual(len(result["additional_banners"]), 2) + self.assertEqual(result["additional_banners"][0]["source"], "os") + self.assertEqual(result["additional_banners"][1]["source"], "product_announcement") diff --git a/unittests/test_product_announcements.py b/unittests/test_product_announcements.py new file mode 100644 index 00000000000..d9aa39a191f --- /dev/null +++ b/unittests/test_product_announcements.py @@ -0,0 +1,214 @@ +from collections import UserDict + +from django.http import HttpRequest, HttpResponse +from django.test import SimpleTestCase + +from dojo.product_announcements import ( + ErrorPageProductAnnouncement, + LargeScanSizeProductAnnouncement, + LongRunningRequestProductAnnouncement, + ScanTypeProductAnnouncement, +) + + +class _SessionDict(UserDict): + + """Minimal session stand-in that supports .get/.pop/[] like Django sessions.""" + + +def _make_request(): + request = HttpRequest() + request.session = _SessionDict() + return request + + +def _make_response(data=None): + response = HttpResponse() + response.data = data if data is not None else {} + return response + + +class TestProductAnnouncementSessionBanner(SimpleTestCase): + + def test_stores_banner_in_session(self): + request = _make_request() + ErrorPageProductAnnouncement(request=request) + banners = request.session["_product_banners"] + self.assertEqual(len(banners), 1) + self.assertEqual(banners[0]["source"], "product_announcement") + self.assertEqual(banners[0]["style"], "info") + self.assertIn("Pro comes with support.", banners[0]["message"]) + self.assertIsNone(banners[0]["expanded_html"]) + + def test_multiple_announcements_accumulate_in_session(self): + request = _make_request() + ErrorPageProductAnnouncement(request=request) + ErrorPageProductAnnouncement(request=request) + banners = request.session["_product_banners"] + self.assertEqual(len(banners), 2) + + def test_banner_message_contains_outreach_link(self): + request = _make_request() + ErrorPageProductAnnouncement(request=request) + message = request.session["_product_banners"][0]["message"] + self.assertIn("cloud.defectdojo.com", message) + self.assertIn("Try today for free", message) + + def test_session_error_is_swallowed(self): + request = HttpRequest() + request.session = None + ErrorPageProductAnnouncement(request=request) + + def test_no_settings_guard(self): + """Product announcements fire without any settings check.""" + request = _make_request() + ErrorPageProductAnnouncement(request=request) + self.assertEqual(len(request.session["_product_banners"]), 1) + + +class TestProductAnnouncementApiPath(SimpleTestCase): + + def test_api_response_gets_pro_key(self): + response = _make_response(data={}) + ErrorPageProductAnnouncement(response=response) + self.assertIn("pro", response.data) + self.assertEqual(len(response.data["pro"]), 1) + self.assertIn("Pro comes with support.", str(response.data["pro"][0])) + + def test_api_response_appends_to_existing_pro_list(self): + response = _make_response(data={"pro": ["existing"]}) + ErrorPageProductAnnouncement(response=response) + self.assertEqual(len(response.data["pro"]), 2) + self.assertEqual(response.data["pro"][0], "existing") + + def test_api_response_data_dict_gets_pro_key(self): + data = {} + LargeScanSizeProductAnnouncement(response_data=data, duration=120.0) + self.assertIn("pro", data) + + def test_requires_at_least_one_target(self): + with self.assertRaises(ValueError): + ErrorPageProductAnnouncement() + + +class TestErrorPageProductAnnouncement(SimpleTestCase): + + def test_message_content(self): + request = _make_request() + ErrorPageProductAnnouncement(request=request) + message = request.session["_product_banners"][0]["message"] + self.assertIn("Pro comes with support.", message) + + def test_api_path(self): + response = _make_response() + ErrorPageProductAnnouncement(response=response) + self.assertIn("Pro comes with support.", str(response.data["pro"][0])) + + +class TestLargeScanSizeProductAnnouncement(SimpleTestCase): + + def test_fires_when_duration_exceeds_threshold(self): + request = _make_request() + LargeScanSizeProductAnnouncement(request=request, duration=120.0) + banners = request.session["_product_banners"] + self.assertEqual(len(banners), 1) + self.assertIn("import took about 2 minute(s)", banners[0]["message"]) + self.assertIn("async imports", banners[0]["message"]) + + def test_does_not_fire_when_duration_below_threshold(self): + request = _make_request() + LargeScanSizeProductAnnouncement(request=request, duration=30.0) + self.assertEqual(len(request.session.get("_product_banners", [])), 0) + + def test_fires_at_boundary(self): + request = _make_request() + LargeScanSizeProductAnnouncement(request=request, duration=60.1) + self.assertEqual(len(request.session["_product_banners"]), 1) + + def test_does_not_fire_at_exact_threshold(self): + request = _make_request() + LargeScanSizeProductAnnouncement(request=request, duration=60.0) + self.assertEqual(len(request.session.get("_product_banners", [])), 0) + + +class TestLongRunningRequestProductAnnouncement(SimpleTestCase): + + def test_fires_when_duration_exceeds_threshold(self): + request = _make_request() + LongRunningRequestProductAnnouncement(request=request, duration=20.0) + banners = request.session["_product_banners"] + self.assertEqual(len(banners), 1) + self.assertIn("performance tested", banners[0]["message"]) + + def test_does_not_fire_when_duration_below_threshold(self): + request = _make_request() + LongRunningRequestProductAnnouncement(request=request, duration=10.0) + self.assertEqual(len(request.session.get("_product_banners", [])), 0) + + def test_does_not_fire_at_exact_threshold(self): + request = _make_request() + LongRunningRequestProductAnnouncement(request=request, duration=15.0) + self.assertEqual(len(request.session.get("_product_banners", [])), 0) + + +class TestScanTypeProductAnnouncement(SimpleTestCase): + + def test_fires_for_supported_scan_type(self): + request = _make_request() + ScanTypeProductAnnouncement(request=request, scan_type="Snyk Scan") + banners = request.session["_product_banners"] + self.assertEqual(len(banners), 1) + self.assertIn("Snyk Scan", banners[0]["message"]) + self.assertIn("no-code connector", banners[0]["message"]) + + def test_does_not_fire_for_unsupported_scan_type(self): + request = _make_request() + ScanTypeProductAnnouncement(request=request, scan_type="Unknown Scanner") + self.assertEqual(len(request.session.get("_product_banners", [])), 0) + + def test_does_not_fire_for_none_scan_type(self): + request = _make_request() + ScanTypeProductAnnouncement(request=request, scan_type=None) + self.assertEqual(len(request.session.get("_product_banners", [])), 0) + + def test_all_supported_scan_types_fire(self): + for scan_type in ScanTypeProductAnnouncement.supported_scan_types: + request = _make_request() + ScanTypeProductAnnouncement(request=request, scan_type=scan_type) + self.assertEqual( + len(request.session["_product_banners"]), 1, + f"Expected banner for {scan_type}", + ) + + def test_api_path_for_supported_scan_type(self): + data = {} + ScanTypeProductAnnouncement(response_data=data, scan_type="Wiz Scan") + self.assertIn("pro", data) + self.assertIn("Wiz Scan", str(data["pro"][0])) + + +class TestBannerDictSchema(SimpleTestCase): + + """Verify every banner stored in the session has the expected keys.""" + + EXPECTED_KEYS = {"source", "message", "style", "url", "link_text", "expanded_html"} + + def test_error_page_banner_has_all_keys(self): + request = _make_request() + ErrorPageProductAnnouncement(request=request) + self.assertEqual(set(request.session["_product_banners"][0].keys()), self.EXPECTED_KEYS) + + def test_large_scan_banner_has_all_keys(self): + request = _make_request() + LargeScanSizeProductAnnouncement(request=request, duration=120.0) + self.assertEqual(set(request.session["_product_banners"][0].keys()), self.EXPECTED_KEYS) + + def test_long_running_banner_has_all_keys(self): + request = _make_request() + LongRunningRequestProductAnnouncement(request=request, duration=20.0) + self.assertEqual(set(request.session["_product_banners"][0].keys()), self.EXPECTED_KEYS) + + def test_scan_type_banner_has_all_keys(self): + request = _make_request() + ScanTypeProductAnnouncement(request=request, scan_type="Snyk Scan") + self.assertEqual(set(request.session["_product_banners"][0].keys()), self.EXPECTED_KEYS)