1+ # Copyright (c) Microsoft Corporation.
2+ # Licensed under the MIT license.
3+
4+ """
5+ Tests for modality support detection in prompt targets.
6+ """
7+
8+ import pytest
9+
10+ from pyrit .memory import CentralMemory
11+ from pyrit .prompt_target .openai .openai_chat_target import OpenAIChatTarget
12+ from pyrit .prompt_target .text_target import TextTarget
13+ from pyrit .prompt_target .hugging_face .hugging_face_chat_target import HuggingFaceChatTarget
14+
15+
16+ class TestModalitySupport :
17+ """
18+ Test cases for modality support detection in prompt targets.
19+ """
20+
21+ @pytest .fixture (autouse = True )
22+ def sqlite_instance (self ):
23+ """Initialize in-memory SQLite database for testing."""
24+ memory_instance = CentralMemory ()
25+ CentralMemory .set_memory_instance (memory_instance )
26+ yield memory_instance
27+
28+ def test_openai_chat_target_supports_multimodal (self , sqlite_instance ):
29+ """
30+ Test that OpenAIChatTarget declares multimodal support correctly.
31+ """
32+ target = OpenAIChatTarget (
33+ endpoint = "https://test.openai.azure.com/" ,
34+ model_name = "gpt-4o" ,
35+ api_key = "test" ,
36+ )
37+
38+ # OpenAI Chat should support both text and image_path
39+ assert target .input_modality_supported ("text" )
40+ assert target .input_modality_supported ("image_path" )
41+ assert not target .input_modality_supported ("audio" )
42+ assert target .supports_multimodal_input ()
43+ assert "text" in target .supported_input_modalities
44+ assert "image_path" in target .supported_input_modalities
45+ assert len (target .supported_input_modalities ) == 2
46+
47+ def test_text_target_supports_text_only (self , sqlite_instance ):
48+ """
49+ Test that TextTarget declares text-only support correctly.
50+ """
51+ target = TextTarget ()
52+
53+ # TextTarget should only support text
54+ assert target .input_modality_supported ("text" )
55+ assert not target .input_modality_supported ("image_path" )
56+ assert not target .input_modality_supported ("audio" )
57+ assert not target .supports_multimodal_input ()
58+ assert target .supported_input_modalities == ["text" ]
59+ assert len (target .supported_input_modalities ) == 1
60+
61+ def test_huggingface_chat_target_supports_text_only (self , sqlite_instance ):
62+ """
63+ Test that HuggingFaceChatTarget declares text-only support correctly.
64+ """
65+ target = HuggingFaceChatTarget (
66+ model_id = "microsoft/DialoGPT-medium" ,
67+ use_cuda = False , # Avoid GPU dependency in tests
68+ )
69+
70+ # HuggingFace Chat should only support text (for now)
71+ assert target .input_modality_supported ("text" )
72+ assert not target .input_modality_supported ("image_path" )
73+ assert not target .input_modality_supported ("audio" )
74+ assert not target .supports_multimodal_input ()
75+ assert target .supported_input_modalities == ["text" ]
76+ assert len (target .supported_input_modalities ) == 1
77+
78+ def test_target_identifier_includes_modality_fields (self , sqlite_instance ):
79+ """
80+ Test that target identifiers include modality support information.
81+ """
82+ # Test multimodal target (OpenAI)
83+ openai_target = OpenAIChatTarget (
84+ endpoint = "https://test.openai.azure.com/" ,
85+ model_name = "gpt-4o" ,
86+ api_key = "test" ,
87+ )
88+ openai_id = openai_target .get_identifier ()
89+ assert openai_id .supported_input_modalities is not None
90+ assert "text" in openai_id .supported_input_modalities
91+ assert "image_path" in openai_id .supported_input_modalities
92+ assert openai_id .supports_conversation_history is True
93+
94+ # Test text-only target
95+ text_target = TextTarget ()
96+ text_id = text_target .get_identifier ()
97+ assert text_id .supported_input_modalities == ["text" ]
98+ assert text_id .supports_conversation_history is True
99+
100+ def test_modality_support_detection_differentiates_targets (self , sqlite_instance ):
101+ """
102+ Test that modality support detection can differentiate between target types.
103+ """
104+ openai_target = OpenAIChatTarget (
105+ endpoint = "https://test.openai.azure.com/" ,
106+ model_name = "gpt-4o" ,
107+ api_key = "test" ,
108+ )
109+ text_target = TextTarget ()
110+
111+ # Test differentiation
112+ assert openai_target .supports_multimodal_input ()
113+ assert not text_target .supports_multimodal_input ()
114+
115+ # Test that they support different modality sets
116+ assert set (openai_target .supported_input_modalities ) != set (text_target .supported_input_modalities )
117+
118+ # Both should support text
119+ assert openai_target .input_modality_supported ("text" )
120+ assert text_target .input_modality_supported ("text" )
121+
122+ # Only OpenAI should support image_path
123+ assert openai_target .input_modality_supported ("image_path" )
124+ assert not text_target .input_modality_supported ("image_path" )
125+
126+ def test_modality_support_properties_are_immutable (self , sqlite_instance ):
127+ """
128+ Test that modality support properties return copies and are not directly modifiable.
129+ """
130+ target = OpenAIChatTarget (
131+ endpoint = "https://test.openai.azure.com/" ,
132+ model_name = "gpt-4o" ,
133+ api_key = "test" ,
134+ )
135+
136+ # Get the list and modify it
137+ modalities = target .supported_input_modalities
138+ original_length = len (modalities )
139+ modalities .append ("fake_modality" )
140+
141+ # Verify the target's actual list wasn't modified
142+ assert len (target .supported_input_modalities ) == original_length
143+ assert "fake_modality" not in target .supported_input_modalities
0 commit comments