diff --git a/build_scripts/evaluate_scorers.py b/build_scripts/evaluate_scorers.py index 3be1cd35be..27ba1a74e1 100644 --- a/build_scripts/evaluate_scorers.py +++ b/build_scripts/evaluate_scorers.py @@ -61,12 +61,12 @@ async def evaluate_scorers(tags: list[str] | None = None, max_concurrency: int = if tags: scorer_names: list[str] = [] for tag in tags: - entries = registry.get_by_tag(tag=tag) + entries = registry.instances.get_by_tag(tag=tag) scorer_names.extend(entry.name for entry in entries if entry.name not in scorer_names) scorer_names.sort() print(f"\nFiltering by tags: {tags}") else: - scorer_names = registry.get_names() + scorer_names = registry.instances.get_names() if not scorer_names: print("No scorers registered. Check environment variable configuration.") @@ -85,7 +85,7 @@ async def evaluate_scorers(tags: list[str] | None = None, max_concurrency: int = # Evaluate each scorer for i, scorer_name in scorer_iterator: - scorer = registry.get_instance_by_name(scorer_name) + scorer = registry.instances.get(scorer_name) print(f"\n[{i}/{len(scorer_names)}] Evaluating {scorer_name}...") print(" Status: Starting evaluation (this may take several minutes)...") diff --git a/doc/code/registry/2_instance_registry.ipynb b/doc/code/registry/2_instance_registry.ipynb index 67161d0827..db9df530f6 100644 --- a/doc/code/registry/2_instance_registry.ipynb +++ b/doc/code/registry/2_instance_registry.ipynb @@ -22,7 +22,7 @@ "source": [ "## Listing Available Instances\n", "\n", - "Use `get_names()` to see registered instances, or `list_metadata()` for details." + "Use `instances.get_names()` to see registered instances, or `instances.list_metadata()` for details." ] }, { @@ -69,10 +69,10 @@ "# Register a scorer instance for demonstration\n", "chat_target = OpenAIChatTarget()\n", "refusal_scorer = SelfAskRefusalScorer(chat_target=chat_target)\n", - "registry.register_instance(refusal_scorer)\n", + "registry.instances.register(refusal_scorer)\n", "\n", "# List what's available\n", - "names = registry.get_names()\n", + "names = registry.instances.get_names()\n", "print(f\"Registered scorers: {names}\")" ] }, @@ -83,7 +83,7 @@ "source": [ "## Getting an Instance\n", "\n", - "Use `get()` to retrieve a pre-configured instance by name. The instance is ready to use immediately." + "Use `instances.get()` to retrieve a pre-configured instance by name. The instance is ready to use immediately." ] }, { @@ -105,7 +105,7 @@ "# Get the first registered scorer\n", "if names:\n", " scorer_name = names[0]\n", - " scorer = registry.get(scorer_name)\n", + " scorer = registry.instances.get(scorer_name)\n", " print(f\"Retrieved scorer: {scorer}\")\n", " print(f\"Scorer type: {type(scorer).__name__}\")" ] @@ -151,7 +151,7 @@ "from pyrit.output import output_scorer_async\n", "\n", "# Get metadata for all registered scorers\n", - "metadata = registry.list_metadata()\n", + "metadata = registry.instances.list_metadata()\n", "for item in metadata:\n", " print(f\"\\n{item.unique_name}:\")\n", " print(f\" Class: {item.class_name}\")\n", @@ -188,15 +188,15 @@ ], "source": [ "# Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer)\n", - "true_false_scorers = registry.list_metadata(include_filters={\"scorer_type\": \"true_false\"})\n", + "true_false_scorers = registry.instances.list_metadata(include_filters={\"scorer_type\": \"true_false\"})\n", "print(f\"True/False scorers: {[m.unique_name for m in true_false_scorers]}\")\n", "\n", "# Filter by class_name\n", - "refusal_scorers = registry.list_metadata(include_filters={\"class_name\": \"SelfAskRefusalScorer\"})\n", + "refusal_scorers = registry.instances.list_metadata(include_filters={\"class_name\": \"SelfAskRefusalScorer\"})\n", "print(f\"Refusal scorers: {[m.unique_name for m in refusal_scorers]}\")\n", "\n", "# Combine multiple filters (AND logic)\n", - "specific_scorers = registry.list_metadata(\n", + "specific_scorers = registry.instances.list_metadata(\n", " include_filters={\"scorer_type\": \"true_false\", \"class_name\": \"SelfAskRefusalScorer\"}\n", ")\n", "print(f\"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}\")" @@ -248,7 +248,7 @@ "# Get the registry singleton\n", "registry = TargetRegistry.get_registry_singleton()\n", "# List registered targets\n", - "target_names = registry.get_names()\n", + "target_names = registry.instances.get_names()\n", "print(f\"Registered targets after initialization: {target_names}\")" ] } diff --git a/doc/code/registry/2_instance_registry.py b/doc/code/registry/2_instance_registry.py index 4eebeb3b89..b595a0bb48 100644 --- a/doc/code/registry/2_instance_registry.py +++ b/doc/code/registry/2_instance_registry.py @@ -21,7 +21,7 @@ # %% [markdown] # ## Listing Available Instances # -# Use `get_names()` to see registered instances, or `list_metadata()` for details. +# Use `instances.get_names()` to see registered instances, or `instances.list_metadata()` for details. # %% from pyrit.prompt_target import OpenAIChatTarget @@ -37,22 +37,22 @@ # Register a scorer instance for demonstration chat_target = OpenAIChatTarget() refusal_scorer = SelfAskRefusalScorer(chat_target=chat_target) -registry.register_instance(refusal_scorer) +registry.instances.register(refusal_scorer) # List what's available -names = registry.get_names() +names = registry.instances.get_names() print(f"Registered scorers: {names}") # %% [markdown] # ## Getting an Instance # -# Use `get()` to retrieve a pre-configured instance by name. The instance is ready to use immediately. +# Use `instances.get()` to retrieve a pre-configured instance by name. The instance is ready to use immediately. # %% # Get the first registered scorer if names: scorer_name = names[0] - scorer = registry.get(scorer_name) + scorer = registry.instances.get(scorer_name) print(f"Retrieved scorer: {scorer}") print(f"Scorer type: {type(scorer).__name__}") @@ -65,7 +65,7 @@ from pyrit.output import output_scorer_async # Get metadata for all registered scorers -metadata = registry.list_metadata() +metadata = registry.instances.list_metadata() for item in metadata: print(f"\n{item.unique_name}:") print(f" Class: {item.class_name}") @@ -80,15 +80,15 @@ # %% # Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer) -true_false_scorers = registry.list_metadata(include_filters={"scorer_type": "true_false"}) +true_false_scorers = registry.instances.list_metadata(include_filters={"scorer_type": "true_false"}) print(f"True/False scorers: {[m.unique_name for m in true_false_scorers]}") # Filter by class_name -refusal_scorers = registry.list_metadata(include_filters={"class_name": "SelfAskRefusalScorer"}) +refusal_scorers = registry.instances.list_metadata(include_filters={"class_name": "SelfAskRefusalScorer"}) print(f"Refusal scorers: {[m.unique_name for m in refusal_scorers]}") # Combine multiple filters (AND logic) -specific_scorers = registry.list_metadata( +specific_scorers = registry.instances.list_metadata( include_filters={"scorer_type": "true_false", "class_name": "SelfAskRefusalScorer"} ) print(f"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}") @@ -111,5 +111,5 @@ # Get the registry singleton registry = TargetRegistry.get_registry_singleton() # List registered targets -target_names = registry.get_names() +target_names = registry.instances.get_names() print(f"Registered targets after initialization: {target_names}") diff --git a/doc/code/scenarios/0_scenarios.ipynb b/doc/code/scenarios/0_scenarios.ipynb index a4f64360d1..716deb78aa 100644 --- a/doc/code/scenarios/0_scenarios.ipynb +++ b/doc/code/scenarios/0_scenarios.ipynb @@ -336,8 +336,8 @@ " ``supported_parameters``). Each target must already be registered in\n", " ``TargetRegistry`` — typically by ``TargetInitializer`` from\n", " ``ADVERSARIAL_CHAT_*`` env vars, or programmatically via\n", - " ``TargetRegistry.register_instance``. At run time,\n", - " ``_get_atomic_attacks_async`` performs the ``(technique ×\n", + " ``TargetRegistry.get_registry_singleton().instances.register``. At run\n", + " time, ``_get_atomic_attacks_async`` performs the ``(technique ×\n", " adversarial_target × dataset)`` cross-product: for each selected\n", " adversarial-capable ``core`` factory in the ``AttackTechniqueRegistry``\n", " and each requested target, it calls\n", @@ -356,7 +356,7 @@ " Default Datasets (1, max 8 per dataset):\n", " harmbench\n", " Supported Parameters:\n", - " - adversarial_targets (list[str]): Registry names of adversarial chat targets to benchmark. Each name must already be registered in TargetRegistry (via TargetInitializer or TargetRegistry.register_instance). Use 'pyrit_scan list-targets' to see registered targets. Settable via --adversarial-targets [ ...] on the CLI, or scenario.args.adversarial_targets in .pyrit_conf.\n", + " - adversarial_targets (list[str]): Registry names of adversarial chat targets to benchmark. Each name must already be registered in TargetRegistry (via TargetInitializer or TargetRegistry instance registration). Use 'pyrit_scan list-targets' to see registered targets. Settable via --adversarial-targets [ ...] on the CLI, or scenario.args.adversarial_targets in .pyrit_conf.\n", "\u001b[1m\u001b[36m\n", " foundry.red_team_agent\u001b[0m\n", " Class: RedTeamAgent\n", diff --git a/doc/code/scenarios/1_common_scenario_parameters.ipynb b/doc/code/scenarios/1_common_scenario_parameters.ipynb index 2cd6692b27..0980df4a78 100644 --- a/doc/code/scenarios/1_common_scenario_parameters.ipynb +++ b/doc/code/scenarios/1_common_scenario_parameters.ipynb @@ -85,7 +85,7 @@ "\n", "await initialize_from_config_async(config_path=Path(\"../../scanner/pyrit_conf.yaml\")) # type: ignore\n", "\n", - "objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name(\"openai_chat\")" + "objective_target = TargetRegistry.get_registry_singleton().instances.get(\"openai_chat\")" ] }, { diff --git a/doc/code/scenarios/1_common_scenario_parameters.py b/doc/code/scenarios/1_common_scenario_parameters.py index 28d83a6527..908c1c0b42 100644 --- a/doc/code/scenarios/1_common_scenario_parameters.py +++ b/doc/code/scenarios/1_common_scenario_parameters.py @@ -35,7 +35,7 @@ await initialize_from_config_async(config_path=Path("../../scanner/pyrit_conf.yaml")) # type: ignore -objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name("openai_chat") +objective_target = TargetRegistry.get_registry_singleton().instances.get("openai_chat") # %% [markdown] # ## Dataset Configuration # diff --git a/doc/code/scenarios/3_adaptive_scenarios.ipynb b/doc/code/scenarios/3_adaptive_scenarios.ipynb index 259d4470f7..6c2d13caff 100644 --- a/doc/code/scenarios/3_adaptive_scenarios.ipynb +++ b/doc/code/scenarios/3_adaptive_scenarios.ipynb @@ -115,7 +115,7 @@ "\n", "await initialize_from_config_async(config_path=Path(\"../../scanner/pyrit_conf.yaml\")) # type: ignore\n", "\n", - "objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name(\"openai_chat\")\n", + "objective_target = TargetRegistry.get_registry_singleton().instances.get(\"openai_chat\")\n", "printer = ConsoleScenarioResultPrinter()" ] }, diff --git a/doc/code/scenarios/3_adaptive_scenarios.py b/doc/code/scenarios/3_adaptive_scenarios.py index 6239a24077..5466286b78 100644 --- a/doc/code/scenarios/3_adaptive_scenarios.py +++ b/doc/code/scenarios/3_adaptive_scenarios.py @@ -54,7 +54,7 @@ await initialize_from_config_async(config_path=Path("../../scanner/pyrit_conf.yaml")) # type: ignore -objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name("openai_chat") +objective_target = TargetRegistry.get_registry_singleton().instances.get("openai_chat") printer = ConsoleScenarioResultPrinter() # %% [markdown] diff --git a/doc/scanner/foundry.ipynb b/doc/scanner/foundry.ipynb index 9035823c9e..f7cfec295e 100644 --- a/doc/scanner/foundry.ipynb +++ b/doc/scanner/foundry.ipynb @@ -50,7 +50,7 @@ "\n", "await initialize_from_config_async(config_path=Path(\"pyrit_conf.yaml\")) # type: ignore\n", "\n", - "objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name(\"openai_chat\")" + "objective_target = TargetRegistry.get_registry_singleton().instances.get(\"openai_chat\")" ] }, { diff --git a/doc/scanner/foundry.py b/doc/scanner/foundry.py index d4f9764c3a..4bb046133a 100644 --- a/doc/scanner/foundry.py +++ b/doc/scanner/foundry.py @@ -30,7 +30,7 @@ await initialize_from_config_async(config_path=Path("pyrit_conf.yaml")) # type: ignore -objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name("openai_chat") +objective_target = TargetRegistry.get_registry_singleton().instances.get("openai_chat") # %% [markdown] # ## RedTeamAgent # diff --git a/doc/scanner/garak.ipynb b/doc/scanner/garak.ipynb index 6b9d68b91b..19a07c13fc 100644 --- a/doc/scanner/garak.ipynb +++ b/doc/scanner/garak.ipynb @@ -49,7 +49,7 @@ "\n", "await initialize_from_config_async(config_path=Path(\"pyrit_conf.yaml\")) # type: ignore\n", "\n", - "objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name(\"openai_chat\")" + "objective_target = TargetRegistry.get_registry_singleton().instances.get(\"openai_chat\")" ] }, { diff --git a/doc/scanner/garak.py b/doc/scanner/garak.py index e86c03146f..a78d58dabf 100644 --- a/doc/scanner/garak.py +++ b/doc/scanner/garak.py @@ -29,7 +29,7 @@ await initialize_from_config_async(config_path=Path("pyrit_conf.yaml")) # type: ignore -objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name("openai_chat") +objective_target = TargetRegistry.get_registry_singleton().instances.get("openai_chat") # %% [markdown] # ## Encoding # diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index b4997828a0..ea6ec134bd 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -243,9 +243,9 @@ def _resolve_target(self, *, request: RunScenarioRequest) -> "PromptTarget": ValueError: If the target is not found in the registry. """ target_registry = TargetRegistry.get_registry_singleton() - objective_target = target_registry.get_instance_by_name(request.target_name) + objective_target = target_registry.instances.get(request.target_name) if objective_target is None: - available_names = target_registry.get_names() + available_names = target_registry.instances.get_names() if not available_names: raise ValueError( f"Target '{request.target_name}' not found. The target registry is empty. " diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 6663dfa57b..27248d9a3d 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -31,7 +31,7 @@ from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.openai.openai_target import OpenAITarget from pyrit.prompt_target.round_robin_target import RoundRobinTarget -from pyrit.registry.object_registries import TargetRegistry +from pyrit.registry import TargetRegistry logger = logging.getLogger(__name__) @@ -192,7 +192,7 @@ async def list_targets_async( """ items = [ self._build_instance_from_object(target_registry_name=entry.name, target_obj=entry.instance) - for entry in self._registry.get_all_instances() + for entry in self._registry.instances.get_all_instances() ] page, has_more = self._paginate(items=items, cursor=cursor, limit=limit) next_cursor = page[-1].target_registry_name if has_more and page else None @@ -227,7 +227,7 @@ async def get_target_async(self, *, target_registry_name: str) -> TargetInstance Returns: TargetInstance if found, None otherwise. """ - obj = self._registry.get_instance_by_name(target_registry_name) + obj = self._registry.instances.get(target_registry_name) if obj is None: return None return self._build_instance_from_object(target_registry_name=target_registry_name, target_obj=obj) @@ -239,7 +239,7 @@ def get_target_object(self, *, target_registry_name: str) -> Any | None: Returns: The PromptTarget object if found, None otherwise. """ - return self._registry.get_instance_by_name(target_registry_name) + return self._registry.instances.get(target_registry_name) async def create_target_async(self, *, request: CreateTargetRequest) -> TargetInstance: """ @@ -281,7 +281,7 @@ async def create_target_async(self, *, request: CreateTargetRequest) -> TargetIn target_obj = target_class(**params) - self._registry.register_instance(target_obj) + self._registry.instances.register(target_obj) target_registry_name = target_obj.get_identifier().unique_name return self._build_instance_from_object(target_registry_name=target_registry_name, target_obj=target_obj) @@ -334,7 +334,7 @@ def _create_round_robin_target(self, *, params: dict[str, Any]) -> RoundRobinTar resolved_weights: list[int] = [] duplicates: list[str] = [] for idx, name in enumerate(registry_names): - target_obj = self._registry.get_instance_by_name(name) + target_obj = self._registry.instances.get(name) if target_obj is None: raise ValueError(f"Target '{name}' not found in the registry.") target_hash = target_obj.get_identifier().hash diff --git a/pyrit/models/identifiers/scorer_identifier.py b/pyrit/models/identifiers/scorer_identifier.py index 8912230a27..08ce2b7858 100644 --- a/pyrit/models/identifiers/scorer_identifier.py +++ b/pyrit/models/identifiers/scorer_identifier.py @@ -11,6 +11,7 @@ from pyrit.models.identifiers.component_identifier import ComponentIdentifier from pyrit.models.identifiers.evaluation_markers import Evaluate +from pyrit.models.identifiers.param_markers import Param from pyrit.models.identifiers.target_identifier import ( # noqa: TC001 TargetIdentifier, # runtime-required by Pydantic field annotations ) @@ -24,6 +25,12 @@ class ScorerIdentifier(ComponentIdentifier): Promotes the ``scorer_type`` discriminator, the ``score_aggregator`` name, and the scorer's own child slots — ``prompt_target`` (an LLM target) and ``sub_scorers`` (nested scorers). + + Build markers (``Param.*``) declare how the child slots map to the scorer's + constructor: ``prompt_target`` is an included parameter aliased to the + ``chat_target`` constructor arg, and ``sub_scorers`` is an included parameter + aliased to the composite scorer's ``scorers`` arg. Their identifier types make + them references resolved by name from the target and scorer registries. """ component_type: ClassVar[ComponentType] = ComponentType.SCORER @@ -32,7 +39,11 @@ class ScorerIdentifier(ComponentIdentifier): scorer_type: Annotated[str | None, Evaluate.Include()] = None #: Name of the aggregator function combining sub-scores (e.g., ``"AND_"``). score_aggregator: Annotated[str | None, Evaluate.Include()] = None - #: Target an LLM-backed scorer calls (e.g., ``SelfAskScaleScorer``). - prompt_target: Annotated[TargetIdentifier | None, Evaluate.Include()] = None - #: Nested scorers a composite wraps, typed recursively. - sub_scorers: Annotated[list[ScorerIdentifier], Evaluate.Include()] = Field(default_factory=list) + #: Target an LLM-backed scorer calls (e.g., ``SelfAskScaleScorer``). The + #: constructor arg is ``chat_target``, so the build marker aliases it. + prompt_target: Annotated[TargetIdentifier | None, Evaluate.Include(), Param.Include(alias="chat_target")] = None + #: Nested scorers a composite wraps, typed recursively. The composite + #: constructor arg is ``scorers`` (a list), so the build marker aliases it. + sub_scorers: Annotated[list[ScorerIdentifier], Evaluate.Include(), Param.Include(alias="scorers")] = Field( + default_factory=list + ) diff --git a/pyrit/models/identifiers/target_identifier.py b/pyrit/models/identifiers/target_identifier.py index c5a910def2..2070322c93 100644 --- a/pyrit/models/identifiers/target_identifier.py +++ b/pyrit/models/identifiers/target_identifier.py @@ -11,6 +11,7 @@ from pyrit.models.identifiers.component_identifier import ComponentIdentifier from pyrit.models.identifiers.evaluation_markers import Evaluate +from pyrit.models.identifiers.param_markers import Param from pyrit.models.parameter import ComponentType @@ -46,5 +47,7 @@ class TargetIdentifier(ComponentIdentifier): top_p: Annotated[float | None, Evaluate.Include()] = None #: Maximum requests per minute. max_requests_per_minute: Annotated[int | None, Evaluate.Exclude()] = None - #: Inner targets of a multi-target (e.g., ``RoundRobinTarget``), typed recursively. - targets: Annotated[list[TargetIdentifier], Evaluate.Unwrap()] = Field(default_factory=list) + #: Inner targets of a multi-target (e.g., ``RoundRobinTarget``), typed + #: recursively. An included constructor parameter (the ctor arg is also + #: ``targets``, a list) resolved by name from the target registry. + targets: Annotated[list[TargetIdentifier], Evaluate.Unwrap(), Param.Include()] = Field(default_factory=list) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index b6cb751387..ece339a18b 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -16,6 +16,10 @@ from pyrit.registry.components import ( ConverterMetadata, ConverterRegistry, + ScorerMetadata, + ScorerRegistry, + TargetMetadata, + TargetRegistry, ) from pyrit.registry.discovery import ( discover_in_directory, @@ -31,9 +35,6 @@ AttackTechniqueRegistry, BaseInstanceRegistry, RegistryEntry, - RetrievableInstanceRegistry, - ScorerRegistry, - TargetRegistry, ) from pyrit.registry.registry import Registry from pyrit.registry.tag_query import TagQuery @@ -47,7 +48,6 @@ "DefaultInstanceRegistry", "InstanceRegistry", "Registry", - "RetrievableInstanceRegistry", "SupportsInstances", "ClassEntry", "discover_in_directory", @@ -61,6 +61,8 @@ "ScenarioParameterMetadata", "ScenarioRegistry", "ScorerRegistry", + "ScorerMetadata", "TargetRegistry", + "TargetMetadata", "TagQuery", ] diff --git a/pyrit/registry/components/__init__.py b/pyrit/registry/components/__init__.py index 38faacaeda..fa0eb5d5d8 100644 --- a/pyrit/registry/components/__init__.py +++ b/pyrit/registry/components/__init__.py @@ -18,8 +18,20 @@ ConverterMetadata, ConverterRegistry, ) +from pyrit.registry.components.scorer_registry import ( + ScorerMetadata, + ScorerRegistry, +) +from pyrit.registry.components.target_registry import ( + TargetMetadata, + TargetRegistry, +) __all__ = [ "ConverterRegistry", "ConverterMetadata", + "ScorerRegistry", + "ScorerMetadata", + "TargetRegistry", + "TargetMetadata", ] diff --git a/pyrit/registry/components/converter_registry.py b/pyrit/registry/components/converter_registry.py index 0e224ea17d..eb8bf68964 100644 --- a/pyrit/registry/components/converter_registry.py +++ b/pyrit/registry/components/converter_registry.py @@ -23,7 +23,6 @@ from __future__ import annotations -import logging from dataclasses import dataclass from typing import TYPE_CHECKING @@ -34,26 +33,10 @@ from pyrit.registry.registry import Registry if TYPE_CHECKING: - from pyrit.prompt_converter import PromptConverter - -logger = logging.getLogger(__name__) - - -def _prompt_converter_type() -> type[PromptConverter]: - """ - Return the ``PromptConverter`` base class, importing it lazily. - - Used as the ``instance_type`` for the registry's ``instances`` container so - a non-converter cannot be registered, without importing the converter - package at module load (which would defeat lazy discovery). + from types import ModuleType - Returns: - type[PromptConverter]: The ``PromptConverter`` base class. - """ from pyrit.prompt_converter import PromptConverter - return PromptConverter - @dataclass(frozen=True) class ConverterMetadata(ClassRegistryEntry): @@ -104,49 +87,30 @@ class ConverterRegistry(Registry["PromptConverter", ConverterMetadata]): def __init__(self, *, lazy_discovery: bool = True) -> None: """ - Initialize the registry. + Initialize the registry and its typed ``instances`` container. Args: lazy_discovery (bool): If True, class discovery is deferred until first access. If False, discovery runs immediately. """ super().__init__(lazy_discovery=lazy_discovery) - self.instances: InstanceRegistry[PromptConverter] = DefaultInstanceRegistry( - instance_type=_prompt_converter_type - ) + self.instances: InstanceRegistry[PromptConverter] = DefaultInstanceRegistry(instance_type=self._base_type) - def _identifier_type(self) -> type[ConverterIdentifier]: - """Return ``ConverterIdentifier`` so its ``Param.*`` markers drive derivation.""" - return ConverterIdentifier - - def _get_registry_name(self, cls: type[PromptConverter]) -> str: - """ - Use the exact class name as the catalog key. - - Converters are referenced by their class name (e.g. ``"Base64Converter"``) - rather than the snake_case default used by other class registries. + def _base_type(self) -> type[PromptConverter]: + """Return the ``PromptConverter`` base class, imported lazily.""" + from pyrit.prompt_converter import PromptConverter - Returns: - str: The class name. - """ - return cls.__name__ + return PromptConverter - def _discover(self) -> None: - """Discover all concrete ``PromptConverter`` subclasses from ``pyrit.prompt_converter``.""" + def _discovery_package(self) -> ModuleType: + """Return the ``pyrit.prompt_converter`` package scanned for converter classes.""" from pyrit import prompt_converter - from pyrit.prompt_converter import PromptConverter - for name in prompt_converter.__all__: - cls = getattr(prompt_converter, name, None) - if cls is None or not isinstance(cls, type): - continue - if not issubclass(cls, PromptConverter) or cls is PromptConverter: - continue - # Key off the class itself (via _get_registry_name) rather than the - # __all__ export name so the catalog key always matches class_name, - # even if an export is ever aliased. - self.register_class(cls) - logger.debug(f"Registered converter class: {cls.__name__}") + return prompt_converter + + def _identifier_type(self) -> type[ConverterIdentifier]: + """Return ``ConverterIdentifier`` so its ``Param.*`` markers drive derivation.""" + return ConverterIdentifier def _metadata_class(self) -> type[ConverterMetadata]: """Return ``ConverterMetadata``; the base populates it from the common fields.""" diff --git a/pyrit/registry/components/scorer_registry.py b/pyrit/registry/components/scorer_registry.py new file mode 100644 index 0000000000..84f195091e --- /dev/null +++ b/pyrit/registry/components/scorer_registry.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scorer registry for PyRIT. + +A single registry for ``Scorer`` that both: + +- **builds** scorers from a type name plus arguments — discovering scorer classes, + deriving their ``Parameter`` contract from the constructor enriched by + ``ScorerIdentifier``'s build markers, and constructing instances via the shared + resolver (so an LLM scorer can be built by passing a ``chat_target`` registry + name, and a composite scorer by passing a list of ``scorers`` registry names), + and +- **holds** pre-configured scorer instances registered via initializers or the + backend. + +It is a ``Registry``: the registry's own surface (``get_class``, +``get_class_names``, ``get_all_registered_class_metadata``, ``create_instance``) +is the buildable class catalog. Pre-configured instances live under the +``instances`` property (``register``, ``get``, ``get_all_instances``, +``get_names``), a ``DefaultInstanceRegistry``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pyrit.models.identifiers import ScorerIdentifier +from pyrit.models.parameter import ComponentType +from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry +from pyrit.registry.registry import Registry + +if TYPE_CHECKING: + from types import ModuleType + + from pyrit.score.scorer import Scorer + + +@dataclass(frozen=True) +class ScorerMetadata(ClassRegistryEntry): + """ + Metadata describing a registered ``Scorer`` class. + + Carries the derived ``parameters`` build contract (the same list the resolver + consumes to build an instance). Whether the scorer is LLM-based is projected + from that contract rather than stored, so the entry can never drift from the + class. + + Use ``ScorerRegistry.get_class()`` to get the actual class or + ``create_instance()`` to build a configured instance. + """ + + @property + def is_llm_based(self) -> bool: + """Whether the scorer requires an LLM target (a TARGET reference parameter).""" + return any(p.is_reference_to(ComponentType.TARGET) for p in self.parameters) + + +class ScorerRegistry(Registry["Scorer", ScorerMetadata]): + """ + Registry that discovers, builds, and holds ``Scorer`` instances. + + Discovers all concrete ``Scorer`` subclasses exported from ``pyrit.score`` + (keyed by their exact class name, e.g. ``"SelfAskRefusalScorer"``) for the + buildable catalog. Pre-configured instances registered via initializers or the + backend are held under the ``instances`` property. + + Building a scorer resolves its arguments through the shared resolver, so LLM + scorers can be constructed by passing a ``chat_target`` that names a target in + the ``TargetRegistry``, and composite scorers by passing a list of ``scorers`` + that name scorers already held under ``instances``. + """ + + def __init__(self, *, lazy_discovery: bool = True) -> None: + """ + Initialize the registry and its typed ``instances`` container. + + Args: + lazy_discovery (bool): If True, class discovery is deferred until first + access. If False, discovery runs immediately. + """ + super().__init__(lazy_discovery=lazy_discovery) + self.instances: InstanceRegistry[Scorer] = DefaultInstanceRegistry(instance_type=self._base_type) + + def _base_type(self) -> type[Scorer]: + """Return the ``Scorer`` base class, imported lazily.""" + from pyrit.score.scorer import Scorer + + return Scorer + + def _discovery_package(self) -> ModuleType: + """Return the ``pyrit.score`` package scanned for scorer classes.""" + from pyrit import score + + return score + + def _identifier_type(self) -> type[ScorerIdentifier]: + """Return ``ScorerIdentifier`` so its ``Param.*`` markers drive derivation.""" + return ScorerIdentifier + + def _metadata_class(self) -> type[ScorerMetadata]: + """Return ``ScorerMetadata``; the base populates it from the common fields.""" + return ScorerMetadata diff --git a/pyrit/registry/components/target_registry.py b/pyrit/registry/components/target_registry.py new file mode 100644 index 0000000000..01e3c4e4dd --- /dev/null +++ b/pyrit/registry/components/target_registry.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target registry for PyRIT. + +A single registry for ``PromptTarget`` that both: + +- **builds** targets from a type name plus arguments — discovering target classes, + deriving their ``Parameter`` contract from the constructor enriched by + ``TargetIdentifier``'s build markers, and constructing instances via the shared + resolver (so a multi-target such as ``RoundRobinTarget`` can be built by passing + a list of ``targets`` registry names), and +- **holds** pre-configured target instances registered via initializers or the + backend. + +It is a ``Registry``: the registry's own surface (``get_class``, +``get_class_names``, ``get_all_registered_class_metadata``, ``create_instance``) +is the buildable class catalog. Pre-configured instances live under the +``instances`` property (``register``, ``get``, ``get_all_instances``, +``get_names``), a ``DefaultInstanceRegistry``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pyrit.models.identifiers import TargetIdentifier +from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry +from pyrit.registry.registry import Registry + +if TYPE_CHECKING: + from types import ModuleType + + from pyrit.prompt_target import PromptTarget + + +@dataclass(frozen=True) +class TargetMetadata(ClassRegistryEntry): + """ + Metadata describing a registered ``PromptTarget`` class. + + Carries the derived ``parameters`` build contract (the same list the resolver + consumes to build an instance). Use ``TargetRegistry.get_class()`` to get the + actual class or ``create_instance()`` to build a configured instance. + """ + + +class TargetRegistry(Registry["PromptTarget", TargetMetadata]): + """ + Registry that discovers, builds, and holds ``PromptTarget`` instances. + + Discovers all concrete ``PromptTarget`` subclasses exported from + ``pyrit.prompt_target`` (keyed by their exact class name, e.g. + ``"OpenAIChatTarget"``) for the buildable catalog. Pre-configured instances + registered via initializers or the backend are held under the ``instances`` + property. + + Building a multi-target resolves its arguments through the shared resolver, so + a ``RoundRobinTarget`` can be constructed by passing a list of ``targets`` that + name targets already held under ``instances``. + """ + + def __init__(self, *, lazy_discovery: bool = True) -> None: + """ + Initialize the registry and its typed ``instances`` container. + + Args: + lazy_discovery (bool): If True, class discovery is deferred until first + access. If False, discovery runs immediately. + """ + super().__init__(lazy_discovery=lazy_discovery) + self.instances: InstanceRegistry[PromptTarget] = DefaultInstanceRegistry(instance_type=self._base_type) + + def _base_type(self) -> type[PromptTarget]: + """Return the ``PromptTarget`` base class, imported lazily.""" + from pyrit.prompt_target import PromptTarget + + return PromptTarget + + def _discovery_package(self) -> ModuleType: + """Return the ``pyrit.prompt_target`` package scanned for target classes.""" + from pyrit import prompt_target + + return prompt_target + + def _identifier_type(self) -> type[TargetIdentifier]: + """Return ``TargetIdentifier`` so its ``Param.*`` markers drive derivation.""" + return TargetIdentifier + + def _metadata_class(self) -> type[TargetMetadata]: + """Return ``TargetMetadata``; the base populates it from the common fields.""" + return TargetMetadata diff --git a/pyrit/registry/object_registries/__init__.py b/pyrit/registry/object_registries/__init__.py index 9694f12bef..7608ed4bd9 100644 --- a/pyrit/registry/object_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -4,11 +4,11 @@ """ Object registries package. -This package contains registries that store pre-configured instances (not classes). -Examples include ScorerRegistry which stores Scorer instances that have been -initialized with their required parameters (e.g., chat_target). - -For registries that store classes (type[T]), see class_registries/. +This package contains the legacy instance-only registry stack still used by +``AttackTechniqueRegistry``. Component registries that hold pre-configured +instances (converters, scorers, targets) now live in ``registry/components/`` as +``Registry`` subclasses that expose their instances via the ``.instances`` +property. """ from pyrit.registry.object_registries.attack_technique_registry import ( @@ -18,23 +18,11 @@ BaseInstanceRegistry, RegistryEntry, ) -from pyrit.registry.object_registries.retrievable_instance_registry import ( - RetrievableInstanceRegistry, -) -from pyrit.registry.object_registries.scorer_registry import ( - ScorerRegistry, -) -from pyrit.registry.object_registries.target_registry import ( - TargetRegistry, -) __all__ = [ # Base classes "BaseInstanceRegistry", - "RetrievableInstanceRegistry", "RegistryEntry", # Concrete registries "AttackTechniqueRegistry", - "ScorerRegistry", - "TargetRegistry", ] diff --git a/pyrit/registry/object_registries/base_instance_registry.py b/pyrit/registry/object_registries/base_instance_registry.py index 58a7aa2354..73a6ba9393 100644 --- a/pyrit/registry/object_registries/base_instance_registry.py +++ b/pyrit/registry/object_registries/base_instance_registry.py @@ -10,19 +10,18 @@ registries should subclass ``Registry`` (a class catalog that can build instances by name) and hold pre-configured instances via the ``.instances`` property (a ``DefaultInstanceRegistry``). See - ``ConverterRegistry`` for the target shape. This class and - ``RetrievableInstanceRegistry`` remain only because ``TargetRegistry``, - ``ScorerRegistry``, and ``AttackTechniqueRegistry`` still subclass them; - the whole stack is removed once those migrate. + ``ConverterRegistry`` for the target shape. This class remains only + because ``AttackTechniqueRegistry`` still subclasses it; it is removed + once that migrates. This module provides ``BaseInstanceRegistry``, the shared infrastructure for registries that store ``Identifiable`` objects (not classes): singleton lifecycle, registration, tags, metadata, container protocol. Subclass directly for registries that store factories or other -non-retrievable items (e.g., ``AttackTechniqueRegistry``). For registries -where callers retrieve stored objects directly, subclass -``RetrievableInstanceRegistry`` instead. +non-retrievable items (e.g., ``AttackTechniqueRegistry``). For registries +where callers retrieve stored objects directly, use ``Registry`` + the +``.instances`` property (``DefaultInstanceRegistry``) instead. For registries that store classes (type[T]), see ``class_registries/``. """ @@ -59,9 +58,8 @@ class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T via the ``.instances`` property (``DefaultInstanceRegistry``), which carries this same surface (``register``/``get``/``get_by_tag``/ ``add_tags``/``find_dependents_of_tag``/``list_metadata``). This class - survives only for the not-yet-migrated ``TargetRegistry``, - ``ScorerRegistry``, and ``AttackTechniqueRegistry`` and is removed once - they move to ``.instances``. + survives only for the not-yet-migrated ``AttackTechniqueRegistry`` and + is removed once it moves to ``.instances``. Provides singleton lifecycle, registration, tag-based lookup, metadata filtering, and the standard container protocol (``__contains__``, @@ -69,7 +67,8 @@ class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T Subclass directly when stored items should not be retrievable via ``get()`` (e.g., factory registries). For registries that expose - direct item retrieval, subclass ``RetrievableInstanceRegistry`` instead. + direct item retrieval, use ``Registry`` + the ``.instances`` property + (``DefaultInstanceRegistry``) instead. All stored items must implement ``Identifiable``, which provides ``get_identifier()`` for metadata generation. diff --git a/pyrit/registry/object_registries/retrievable_instance_registry.py b/pyrit/registry/object_registries/retrievable_instance_registry.py deleted file mode 100644 index b462c22012..0000000000 --- a/pyrit/registry/object_registries/retrievable_instance_registry.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Retrievable instance registry for PyRIT. - -.. note:: - - **Legacy stack — do not build new registries on this.** New component - registries subclass ``Registry`` and retain instances via the - ``.instances`` property (``DefaultInstanceRegistry``), which already - provides ``get``/``get_entry``/``get_all_instances``. See - ``ConverterRegistry`` for the target shape. This class remains only for the - not-yet-migrated ``ScorerRegistry`` and ``TargetRegistry`` and is removed - once they migrate. - -This module provides ``RetrievableInstanceRegistry``, which extends -``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and -``get_all_instances()`` for registries where callers retrieve stored -objects directly (e.g., ``ScorerRegistry``, ``TargetRegistry``). - -For the shared base class, see ``base_instance_registry``. -For registries that store classes (type[T]), see ``class_registries/``. -""" - -from __future__ import annotations - -from pyrit.registry.object_registries.base_instance_registry import ( - BaseInstanceRegistry, - RegistryEntry, - T, -) - -# Re-export so existing ``from retrievable_instance_registry import ...`` still works -__all__ = ["RetrievableInstanceRegistry", "BaseInstanceRegistry", "RegistryEntry"] - - -class RetrievableInstanceRegistry(BaseInstanceRegistry[T]): - """ - Base class for registries that store directly-retrievable instances. - - .. note:: - - **Legacy — do not subclass for new registries.** Use - ``Registry`` + the ``.instances`` property - (``DefaultInstanceRegistry``), which already exposes - ``get``/``get_entry``/``get_all_instances``. Retained only for the - not-yet-migrated ``ScorerRegistry`` and ``TargetRegistry``. - - Extends ``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and - ``get_all_instances()`` for registries where callers retrieve the - stored objects directly (e.g., scorers, converters, targets). - - For registries that store factories or other non-retrievable items, - subclass ``BaseInstanceRegistry`` directly instead. - - Type Parameters: - T: The type of instances stored in the registry (must be Identifiable). - """ - - def get(self, name: str) -> T | None: - """ - Get a registered instance by name. - - Args: - name: The registry name of the instance. - - Returns: - The instance, or None if not found. - """ - entry = self._registry_items.get(name) - if entry is None: - return None - return entry.instance - - def get_entry(self, name: str) -> RegistryEntry[T] | None: - """ - Get a full registry entry by name, including tags. - - Args: - name: The registry name of the entry. - - Returns: - The RegistryEntry, or None if not found. - """ - return self._registry_items.get(name) - - def get_all_instances(self) -> list[RegistryEntry[T]]: - """ - Get all registered entries sorted by name. - - Returns: - List of RegistryEntry objects sorted by name. - """ - return [self._registry_items[name] for name in sorted(self._registry_items.keys())] diff --git a/pyrit/registry/object_registries/scorer_registry.py b/pyrit/registry/object_registries/scorer_registry.py deleted file mode 100644 index d1a938aa30..0000000000 --- a/pyrit/registry/object_registries/scorer_registry.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Scorer registry for discovering and managing PyRIT scorers. - -Scorers are registered explicitly via initializers as pre-configured instances. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from pyrit.registry.object_registries.retrievable_instance_registry import ( - RetrievableInstanceRegistry, -) - -if TYPE_CHECKING: - from pyrit.score.scorer import Scorer - -logger = logging.getLogger(__name__) - - -class ScorerRegistry(RetrievableInstanceRegistry["Scorer"]): - """ - Registry for managing available scorer instances. - - This registry stores pre-configured Scorer instances (not classes). - Scorers are registered explicitly via initializers after being instantiated - with their required parameters (e.g., chat_target). - - Scorers are identified by their snake_case name derived from the class name, - or a custom name provided during registration. - """ - - def register_instance( - self, - scorer: Scorer, - *, - name: str | None = None, - tags: dict[str, str] | list[str] | None = None, - ) -> None: - """ - Register a scorer instance. - - Note: Unlike ScenarioRegistry and InitializerRegistry which register classes, - ScorerRegistry registers pre-configured instances. - - Args: - scorer: The pre-configured scorer instance (not a class). - name: Optional custom registry name. If not provided, - derived from the scorer's unique identifier. - tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). - """ - if name is None: - name = scorer.get_identifier().unique_name - - self.register(scorer, name=name, tags=tags) - logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") - - def get_instance_by_name(self, name: str) -> Scorer | None: - """ - Get a registered scorer instance by name. - - Note: This returns an already-instantiated scorer, not a class. - - Args: - name: The registry name of the scorer. - - Returns: - The scorer instance, or None if not found. - """ - return self.get(name) diff --git a/pyrit/registry/object_registries/target_registry.py b/pyrit/registry/object_registries/target_registry.py deleted file mode 100644 index 170bad2078..0000000000 --- a/pyrit/registry/object_registries/target_registry.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Target registry for discovering and managing PyRIT prompt targets. - -Targets are registered explicitly via initializers as pre-configured instances. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from pyrit.registry.object_registries.retrievable_instance_registry import ( - RetrievableInstanceRegistry, -) - -if TYPE_CHECKING: - from pyrit.prompt_target import PromptTarget - -logger = logging.getLogger(__name__) - - -class TargetRegistry(RetrievableInstanceRegistry["PromptTarget"]): - """ - Registry for managing available prompt target instances. - - This registry stores pre-configured PromptTarget instances (not classes). - Targets are registered explicitly via initializers after being instantiated - with their required parameters (e.g., endpoint, API keys). - - Targets are identified by their snake_case name derived from the class name, - or a custom name provided during registration. - """ - - def register_instance( - self, - target: PromptTarget, - *, - name: str | None = None, - tags: dict[str, str] | list[str] | None = None, - ) -> None: - """ - Register a target instance. - - Note: Unlike ScenarioRegistry and InitializerRegistry which register classes, - TargetRegistry registers pre-configured instances. - - Args: - target: The pre-configured target instance (not a class). - name: Optional custom registry name. If not provided, - derived from class name with identifier hash appended - (e.g., OpenAIChatTarget -> openai_chat_abc123). - tags: Optional tags for categorization. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). - """ - if name is None: - name = target.get_identifier().unique_name - - self.register(target, name=name, tags=tags) - logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") - - def get_instance_by_name(self, name: str) -> PromptTarget | None: - """ - Get a registered target instance by name. - - Note: This returns an already-instantiated target, not a class. - - Args: - name: The registry name of the target. - - Returns: - The target instance, or None if not found. - """ - return self.get(name) diff --git a/pyrit/registry/registry.py b/pyrit/registry/registry.py index c2a9ff3c60..b1ba20992f 100644 --- a/pyrit/registry/registry.py +++ b/pyrit/registry/registry.py @@ -26,24 +26,27 @@ from __future__ import annotations +import inspect +import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Generic, TypeVar -from pyrit.models import class_name_to_snake_case from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.resolution import ( derive_parameters, - is_component_type_resolvable, resolve_constructor_args, ) if TYPE_CHECKING: from collections.abc import Iterator, Mapping + from types import ModuleType from typing_extensions import Self from pyrit.models.identifiers.component_identifier import ComponentIdentifier - from pyrit.models.parameter import Parameter + from pyrit.models.parameter import ComponentType, Parameter + +logger = logging.getLogger(__name__) T = TypeVar("T") MetadataT = TypeVar("MetadataT", bound=ClassRegistryEntry) @@ -138,11 +141,19 @@ class Registry(ABC, Generic[T, MetadataT]): registry-reference parameters are resolved by name from the owning domain. - Singleton support via ``get_registry_singleton()``. - Subclasses must implement: + Subclasses provide the domain specifics: - - ``_discover()`` — populate the catalog by calling ``register_class`` for each class. + - ``_base_type()`` — the base class to discover (and the type the optional + ``instances`` container is constrained to), imported lazily. + - ``_discovery_package()`` — the package whose ``__all__`` is scanned for + concrete subclasses of ``_base_type()``. - ``_metadata_class()`` — return the concrete metadata dataclass the base builds. + The default ``_discover()`` scans ``_discovery_package().__all__`` for concrete + ``_base_type()`` subclasses and registers each by class name. A registry whose + discovery is genuinely different (e.g. a directory or filesystem scan) overrides + ``_discover()`` instead of supplying the two hooks. + Type Parameters: T: The type of classes being registered (e.g. ``PromptConverter``). MetadataT: The metadata dataclass type (e.g. ``ConverterMetadata``). @@ -198,14 +209,65 @@ def _ensure_discovered(self) -> None: self._discover() self._discovered = True - @abstractmethod + def _base_type(self) -> type[T]: + """ + Return the domain base class to discover (e.g. ``PromptTarget``), imported lazily. + + Used by the default ``_discover`` to filter the package's exports, and by + instance-holding registries to constrain their ``instances`` container. + Importing lazily keeps the heavy domain package out of module load so the + registry's lazy discovery is preserved. + + Returns: + type[T]: The domain base class. + + Raises: + NotImplementedError: If neither ``_base_type`` nor ``_discover`` is overridden. + """ + raise NotImplementedError( + f"{type(self).__name__} must implement _base_type()/_discovery_package() or override _discover()." + ) + + def _discovery_package(self) -> ModuleType: + """ + Return the package whose ``__all__`` the default ``_discover`` scans. + + Returns: + ModuleType: The domain package (e.g. ``pyrit.prompt_target``). + + Raises: + NotImplementedError: If neither ``_discovery_package`` nor ``_discover`` is overridden. + """ + raise NotImplementedError( + f"{type(self).__name__} must implement _base_type()/_discovery_package() or override _discover()." + ) + def _discover(self) -> None: """ - Perform discovery of registry classes. + Populate the catalog from the domain package. - Subclasses implement this to populate the catalog by calling - ``self.register_class(cls)`` for each discovered class. + Scans ``_discovery_package().__all__`` and registers every concrete subclass + of ``_base_type()`` (skipping the base itself and abstract classes), keyed by + class name via ``register_class``. Registries with bespoke discovery override + this method instead of supplying ``_base_type``/``_discovery_package``. """ + package = self._discovery_package() + base = self._base_type() + for name in getattr(package, "__all__", []): + cls = getattr(package, name, None) + if cls is None or not isinstance(cls, type): + continue + # Guard against entries that aren't genuine classes. A test elsewhere in the + # suite may patch a package export with a mock (e.g. ``autospec``/``spec=type``) + # that reports ``isinstance(cls, type) is True`` yet makes ``issubclass`` raise + # ``TypeError``; skip anything that isn't a real subclass of the base. + try: + if not issubclass(cls, base) or cls is base or inspect.isabstract(cls): + continue + except TypeError: + continue + self.register_class(cls) + logger.debug(f"Registered {base.__name__} class: {cls.__name__}") @abstractmethod def _metadata_class(self) -> type[MetadataT]: @@ -292,16 +354,17 @@ def _get_registry_name(self, cls: type[T]) -> str: """ Get the catalog name for a class. - Subclasses can override this to customize name derivation. The default - converts CamelCase to snake_case. + Component classes are referenced by their exact class name (e.g. + ``"OpenAIChatTarget"``). Registries whose names follow a different scheme + (e.g. snake_case filenames or dotted paths) override this. Args: cls (type[T]): The class to get a name for. Returns: - str: The catalog name (snake_case identifier by default). + str: The class name. """ - return class_name_to_snake_case(cls.__name__) + return cls.__name__ def _validate_class(self, cls: type[T]) -> None: """ @@ -326,12 +389,31 @@ class whose build contract does not line up with a resolvable reference # cheap, so the small duplication is deliberate rather than worth caching. parameters = self._derive_parameters(cls) for param in parameters: - if param.reference is not None and not is_component_type_resolvable(param.reference.component_type): + if param.reference is not None and not self._is_component_type_resolvable(param.reference.component_type): raise ValueError( f"{cls.__name__}: reference parameter '{param.name}' has no registry wired for component type " f"'{param.reference.component_type}'." ) + @staticmethod + def _is_component_type_resolvable(component_type: ComponentType) -> bool: + """ + Return whether a registry is wired to resolve references of ``component_type``. + + This is the registration-time gate: a reference parameter whose component + type has no paired registry can never be resolved by name and should fail + fast at ``register_class`` time instead of erroring only at build time. + + Args: + component_type (ComponentType): The referenced component family. + + Returns: + bool: True when references of ``component_type`` can be resolved by name. + """ + from pyrit.registry.resolution import _registry_getter_for_component_type + + return _registry_getter_for_component_type(component_type) is not None + def register_class(self, cls: type[T], *, name: str | None = None) -> None: """ Add a class to the catalog after validating it. diff --git a/pyrit/registry/resolution.py b/pyrit/registry/resolution.py index 99b5c61b53..0a207f893e 100644 --- a/pyrit/registry/resolution.py +++ b/pyrit/registry/resolution.py @@ -185,72 +185,48 @@ def get_names(self) -> list[str]: ... -# TODO (Phase 4 — Target/Scorer migration): this function is deliberately left -# in its current, slightly awkward shape until Target/Scorer become unified -# ``Registry`` instances. It wants to be a flat ``ComponentType -> Registry class`` -# mapping, but it can't be one yet because the three families don't share a uniform -# name->instance surface: ``ConverterRegistry`` is a ``Registry`` whose instances -# live under ``.instances``, while ``TargetRegistry``/``ScorerRegistry`` are still -# legacy object registries whose singleton *is* the instance registry (hence the -# ``.instances`` hop for converters but not the others). Once Target/Scorer migrate -# onto ``Registry`` + ``.instances`` (Phase 4), collapse this into a single mapping -# to the registry classes and fold ``is_component_type_resolvable`` into the base -# ``Registry`` as a private method. def _registry_getter_for_component_type(component_type: ComponentType) -> Callable[[], _NamedInstanceRegistry] | None: """ - Return the getter for the registry singleton that resolves a component family. + Return the getter for the instance registry that resolves a component family. This is the one place that must import the concrete registries, so it stays in the resolve layer (the derive layer never imports them). It is the inverse of the identifier's self-reported ``component_type``: given that family, return the - registry that resolves its references by name. + ``.instances`` container that resolves its references by name. + + The three component registries share a uniform surface — each is a ``Registry`` + whose pre-configured instances live under ``.instances`` — so the mapping is a + flat ``ComponentType -> Registry class`` lookup. Returns: Callable[[], _NamedInstanceRegistry] | None: The registry getter, or None when no registry is wired for ``component_type``. """ - from pyrit.registry.components import ConverterRegistry - from pyrit.registry.object_registries import ScorerRegistry, TargetRegistry - - if component_type is ComponentType.TARGET: - return TargetRegistry.get_registry_singleton - if component_type is ComponentType.CONVERTER: - return lambda: ConverterRegistry.get_registry_singleton().instances - if component_type is ComponentType.SCORER: - return ScorerRegistry.get_registry_singleton - return None - - -def is_component_type_resolvable(component_type: ComponentType) -> bool: - """ - Return whether a registry is wired to resolve references of ``component_type``. - - This is the registration-time gate used by buildable registries: a reference - parameter whose component type has no paired registry can never be resolved by - name and should fail fast instead of erroring only at build time. + from pyrit.registry.components import ConverterRegistry, ScorerRegistry, TargetRegistry - NOTE: This belongs on the ``Registry`` base as a private method; it lives here - for now only because it wraps ``_registry_getter_for_component_type``. Both move - together in Phase 4 (see that function's note). - - Returns: - bool: True when references of ``component_type`` can be resolved by name. - """ - return _registry_getter_for_component_type(component_type) is not None + registry_classes = { + ComponentType.TARGET: TargetRegistry, + ComponentType.CONVERTER: ConverterRegistry, + ComponentType.SCORER: ScorerRegistry, + } + registry_class = registry_classes.get(component_type) + if registry_class is None: + return None + return lambda: registry_class.get_registry_singleton().instances -def _resolve_registry_reference( +def _resolve_single_reference( *, value: Any, getter: Callable[[], _NamedInstanceRegistry], owner: str, name: str ) -> Any: """ - Resolve a registry-reference parameter value to a stored instance. + Resolve a single registry-reference value to a stored instance. A string value is looked up by name in the paired registry. An already-built instance passes through unchanged. Args: value (Any): The raw value (a registry name, or an instance to pass through). - getter (Callable[[], _NamedInstanceRegistry]): Returns the registry singleton. + getter (Callable[[], _NamedInstanceRegistry]): Returns the instance registry. owner (str): The owning class name, for error messages. name (str): The parameter name, for error messages. @@ -281,6 +257,58 @@ def _resolve_registry_reference( ) +def _resolve_registry_reference( + *, + value: Any, + getter: Callable[[], _NamedInstanceRegistry], + owner: str, + name: str, + annotation: TypeAnnotation = None, +) -> Any: + """ + Resolve a registry-reference parameter value to stored instance(s). + + A scalar reference resolves a single name (or instance). A reference whose + constructor annotation is a ``list[...]`` resolves a list of names element by + element, so a multi-target (``RoundRobinTarget``) or a composite scorer can be + built from a list of registry names. Each element is resolved by + ``_resolve_single_reference`` (string → lookup, instance → passthrough). + + The value's shape must match the reference's arity: a ``list[...]`` reference + requires a list and a scalar reference rejects one, so a shape mismatch fails + here with a clear message instead of constructing the component with the wrong + argument shape and erroring obscurely downstream. + + Args: + value (Any): The raw value (a name, an instance, or a list of either). + getter (Callable[[], _NamedInstanceRegistry]): Returns the instance registry. + owner (str): The owning class name, for error messages. + name (str): The parameter name, for error messages. + annotation (TypeAnnotation): The constructor parameter's type annotation, + used to detect a ``list[...]`` reference. + + Returns: + Any: The resolved instance, or a list of resolved instances. + + Raises: + ValueError: If a name is not registered, or the value's shape (list vs. + scalar) does not match the reference's arity. + """ + if get_origin(annotation) is list: + if not isinstance(value, list): + raise ValueError( + f"{owner}.{name}: expected a list of registry names or instances for this " + f'reference, but got {type(value).__name__}. Pass a list, e.g. {name}=["a", "b"].' + ) + return [_resolve_single_reference(value=item, getter=getter, owner=owner, name=name) for item in value] + if isinstance(value, list): + raise ValueError( + f"{owner}.{name}: expected a single registry name or instance for this reference, " + f'but got a list. Pass a single value, e.g. {name}="a".' + ) + return _resolve_single_reference(value=value, getter=getter, owner=owner, name=name) + + def resolve_constructor_args( *, cls: type, raw_args: dict[str, Any], identifier_type: type[ComponentIdentifier] | None = None ) -> dict[str, Any]: @@ -324,7 +352,13 @@ def resolve_constructor_args( f"{cls.__name__}.{name}: no registry is wired for component type " f"'{param.reference.component_type}'." ) - resolved[name] = _resolve_registry_reference(value=value, getter=getter, owner=cls.__name__, name=name) + resolved[name] = _resolve_registry_reference( + value=value, + getter=getter, + owner=cls.__name__, + name=name, + annotation=param.reference.annotation, + ) elif isinstance(value, str) and param.is_string_coercible: try: resolved[name] = param.coerce_value(value) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index c920d2cbf7..b912793639 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -379,7 +379,9 @@ def _get_default_objective_scorer(self) -> TrueFalseScorer: # if available either itself, or its chat target will be used chat_target: PromptTarget | None = None registry_default_scorer: TrueFalseScorer | None = None - entries = ScorerRegistry.get_registry_singleton().get_by_tag(tag=ScorerInitializerTags.DEFAULT_OBJECTIVE_SCORER) + entries = ScorerRegistry.get_registry_singleton().instances.get_by_tag( + tag=ScorerInitializerTags.DEFAULT_OBJECTIVE_SCORER + ) if entries and isinstance(entries[0].instance, TrueFalseScorer): registry_default_scorer = entries[0].instance chat_target = registry_default_scorer.get_chat_target() diff --git a/pyrit/scenario/core/scenario_target_defaults.py b/pyrit/scenario/core/scenario_target_defaults.py index b737f9e290..b979279bfe 100644 --- a/pyrit/scenario/core/scenario_target_defaults.py +++ b/pyrit/scenario/core/scenario_target_defaults.py @@ -76,7 +76,7 @@ def _get_default_chat_target( ValueError: If the registry entry exists but is not a PromptTarget. """ registry = TargetRegistry.get_registry_singleton() - target = registry.get(preferred_target_key) + target = registry.instances.get(preferred_target_key) if target is not None: # Check required capabilities first (fail fast) if required_capabilities: diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index b0e0a58070..fd765b3e93 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -83,7 +83,7 @@ class AdversarialBenchmark(Scenario): parameter (declared in ``supported_parameters``). Each target must already be registered in ``TargetRegistry`` — typically by ``TargetInitializer`` from ``ADVERSARIAL_CHAT_*`` env vars, or - programmatically via ``TargetRegistry.register_instance``. + programmatically via ``TargetRegistry.get_registry_singleton().instances.register``. At run time, ``_get_atomic_attacks_async`` performs the ``(technique × adversarial_target × dataset)`` cross-product: for each @@ -130,7 +130,7 @@ def supported_parameters(cls) -> list[Parameter]: description=( "Registry names of adversarial chat targets to benchmark. " "Each name must already be registered in TargetRegistry " - "(via TargetInitializer or TargetRegistry.register_instance). " + "(via TargetInitializer or TargetRegistry instance registration). " "Use 'pyrit_scan list-targets' to see registered targets. " "Settable via --adversarial-targets [ ...] on the CLI, " "or scenario.args.adversarial_targets in .pyrit_conf." @@ -333,14 +333,14 @@ def _resolve_adversarial_targets(self, *, target_names: list[str]) -> list[tuple resolved: list[tuple[str, PromptTarget]] = [] unknown: list[str] = [] for name in target_names: - instance = target_registry.get_instance_by_name(name) + instance = target_registry.instances.get(name) if instance is None: unknown.append(name) else: resolved.append((name, instance)) if unknown: - available = sorted(target_registry.get_names()) + available = sorted(target_registry.instances.get_names()) raise ValueError( f"AdversarialBenchmark: adversarial_targets {sorted(unknown)} not found in TargetRegistry. " f"Available targets: {available}." diff --git a/pyrit/setup/initializers/components/scorers.py b/pyrit/setup/initializers/components/scorers.py index f6949bc55a..91c7296aa8 100644 --- a/pyrit/setup/initializers/components/scorers.py +++ b/pyrit/setup/initializers/components/scorers.py @@ -186,7 +186,7 @@ async def initialize_async(self) -> None: """ target_registry = TargetRegistry.get_registry_singleton() - if len(target_registry) == 0: + if len(target_registry.instances) == 0: raise RuntimeError( "TargetRegistry is empty. TargetInitializer must run before ScorerInitializer. " "Ensure TargetInitializer is included in the initializers list." @@ -464,16 +464,16 @@ def _tag_best_per_category(self) -> None: scorer_registry = self._get_scorer_registry() for best_tag, (preferred_name, category_tag) in self._PREFERRED_BEST.items(): - entry = scorer_registry.get_entry(preferred_name) + entry = scorer_registry.instances.get_entry(preferred_name) if entry is not None: - scorer_registry.add_tags(name=preferred_name, tags=[best_tag]) + scorer_registry.instances.add_tags(name=preferred_name, tags=[best_tag]) logger.info(f"Tagged {preferred_name} as {best_tag}") continue # Fallback: first registered scorer in this category - entries = scorer_registry.get_by_tag(tag=category_tag) + entries = scorer_registry.instances.get_by_tag(tag=category_tag) if entries: - scorer_registry.add_tags(name=entries[0].name, tags=[best_tag]) + scorer_registry.instances.add_tags(name=entries[0].name, tags=[best_tag]) logger.info(f"Tagged {entries[0].name} as {best_tag} (fallback)") else: logger.warning(f"No scorers in category {category_tag}; skipping {best_tag} tagging.") @@ -557,7 +557,7 @@ def _tag_best_objective(self) -> None: best_name: str | None = None best_f1: float = -1.0 - for entry in scorer_registry.get_all_instances(): + for entry in scorer_registry.instances.get_all_instances(): eval_hash = entry.instance.get_identifier().eval_hash if not eval_hash: continue @@ -567,7 +567,7 @@ def _tag_best_objective(self) -> None: best_name = entry.name if best_name is not None: - scorer_registry.add_tags( + scorer_registry.instances.add_tags( name=best_name, tags=[ScorerInitializerTags.BEST_OBJECTIVE, ScorerInitializerTags.DEFAULT_OBJECTIVE_SCORER], ) @@ -576,13 +576,13 @@ def _tag_best_objective(self) -> None: # Fall back: prefer scale_and_refusal, then first composite best_tags: list[str] = [ScorerInitializerTags.BEST_OBJECTIVE, ScorerInitializerTags.DEFAULT_OBJECTIVE_SCORER] - if scorer_registry.get_entry(self.SCALE_AND_REFUSAL): - scorer_registry.add_tags(name=self.SCALE_AND_REFUSAL, tags=best_tags) + if scorer_registry.instances.get_entry(self.SCALE_AND_REFUSAL): + scorer_registry.instances.add_tags(name=self.SCALE_AND_REFUSAL, tags=best_tags) logger.info(f"Tagged {self.SCALE_AND_REFUSAL} as {ScorerInitializerTags.BEST_OBJECTIVE} (default)") else: - composites = scorer_registry.get_by_tag(tag=ScorerInitializerTags.OBJECTIVE_COMPOSITE) + composites = scorer_registry.instances.get_by_tag(tag=ScorerInitializerTags.OBJECTIVE_COMPOSITE) if composites: - scorer_registry.add_tags(name=composites[0].name, tags=best_tags) + scorer_registry.instances.add_tags(name=composites[0].name, tags=best_tags) logger.info(f"Tagged {composites[0].name} as {ScorerInitializerTags.BEST_OBJECTIVE} (fallback)") else: logger.warning("No composite scorers available; skipping best objective tagging.") @@ -598,7 +598,7 @@ def _get_best_scorer(self, best_tag: str) -> Scorer | None: Returns: Scorer | None: The scorer instance if found, otherwise None. """ - entries = self._get_scorer_registry().get_by_tag(tag=best_tag) + entries = self._get_scorer_registry().instances.get_by_tag(tag=best_tag) return entries[0].instance if entries else None def _get_registered_scorer(self, name: str) -> Scorer | None: @@ -608,7 +608,7 @@ def _get_registered_scorer(self, name: str) -> Scorer | None: Returns: Scorer | None: The scorer instance if found, otherwise None. """ - entry = self._get_scorer_registry().get_entry(name) + entry = self._get_scorer_registry().instances.get_entry(name) return entry.instance if entry else None def _get_scorer_registry(self) -> ScorerRegistry: @@ -628,7 +628,7 @@ def _get_chat_target(self, target_name: str) -> "PromptTarget | None": PromptTarget | None: The chat target instance if found, otherwise None. """ target_registry = TargetRegistry.get_registry_singleton() - return target_registry.get_instance_by_name(target_name) + return target_registry.instances.get(target_name) def _get_chat_target_prefer_rr(self, target_name: str) -> "PromptTarget | None": """ @@ -650,12 +650,12 @@ def _get_chat_target_prefer_rr(self, target_name: str) -> "PromptTarget | None": from pyrit.setup.initializers.components.targets import generate_rr_name, get_behavioral_key target_registry = TargetRegistry.get_registry_singleton() - individual = target_registry.get_instance_by_name(target_name) + individual = target_registry.instances.get(target_name) if individual is None: return None rr_name = generate_rr_name(get_behavioral_key(individual)) - rr_target = target_registry.get_instance_by_name(rr_name) + rr_target = target_registry.instances.get(rr_name) if rr_target is not None: return rr_target @@ -700,7 +700,7 @@ def _try_register( try: scorer = factory() - scorer_registry.register_instance(scorer, name=name, tags=list(tags) if tags else None) + scorer_registry.instances.register(scorer, name=name, tags=list(tags) if tags else None) logger.info(f"Registered scorer: {name}") except (ValueError, TypeError, KeyError) as e: logger.warning(f"Skipping scorer {name}: {e}") diff --git a/pyrit/setup/initializers/components/targets.py b/pyrit/setup/initializers/components/targets.py index 2cf3981546..258a18a009 100644 --- a/pyrit/setup/initializers/components/targets.py +++ b/pyrit/setup/initializers/components/targets.py @@ -676,11 +676,13 @@ def _register_target(self, config: TargetConfig) -> None: target = config.target_class(**kwargs) registry = TargetRegistry.get_registry_singleton() - registry.register_instance(target, name=config.registry_name) + registry.instances.register(target, name=config.registry_name) if config.tags: - registry.add_tags(name=config.registry_name, tags=list(config.tags)) + registry.instances.add_tags(name=config.registry_name, tags=list(config.tags)) if config.default_objective_target: - registry.add_tags(name=config.registry_name, tags=[TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET]) + registry.instances.add_tags( + name=config.registry_name, tags=[TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET] + ) self._registered_names.append(config.registry_name) logger.info(f"Registered target: {config.registry_name}") @@ -703,7 +705,7 @@ def _auto_group_targets(self) -> None: # Group registered targets by behavioral key. groups: dict[tuple[Any, ...], list[tuple[str, PromptTarget]]] = defaultdict(list) for name in self._registered_names: - target = registry.get_instance_by_name(name) + target = registry.instances.get(name) if target is None: continue key = get_behavioral_key(target) @@ -741,11 +743,11 @@ def _auto_group_targets(self) -> None: rr_name = generate_rr_name(key) - if rr_name in registry: + if rr_name in registry.instances: logger.debug(f"Skipping auto-group {rr_name}: name already exists in registry") continue - registry.register_instance(rr_target, name=rr_name) + registry.instances.register(rr_target, name=rr_name) logger.info(f"Auto-grouped round-robin target: {rr_name} (members: {member_names})") diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 4cc21e52fd..20e011f64b 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -575,7 +575,6 @@ def _try_instantiate_converter(converter_name: str): (isinstance(ann, type) and issubclass(ann, PromptTarget)) or "PromptTarget" in ann_str ): mock_target = MagicMock(spec=PromptTarget) - mock_target.__class__.__name__ = "MockChatTarget" # Configure get_identifier() to return a real identifier so that # _create_identifier can promote it into the typed child slot. mock_id = ComponentIdentifier( diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 15116a0ac9..a85a2c105a 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -113,8 +113,8 @@ def mock_all_registries(mock_memory): mock_sr.create_instance.return_value = mock_scenario_instance mock_tr = MagicMock() - mock_tr.get_instance_by_name.return_value = MagicMock() - mock_tr.get_names.return_value = ["my_target"] + mock_tr.instances.get.return_value = MagicMock() + mock_tr.instances.get_names.return_value = ["my_target"] mock_ir = MagicMock() mock_ir.get_class.return_value = MagicMock(return_value=MagicMock(initialize_async=AsyncMock())) @@ -174,8 +174,8 @@ async def test_start_run_invalid_target_raises_value_error(self, mock_memory) -> mock_sr.get_class.return_value = MagicMock() mock_tr = MagicMock() - mock_tr.get_instance_by_name.return_value = None - mock_tr.get_names.return_value = ["other_target"] + mock_tr.instances.get.return_value = None + mock_tr.instances.get_names.return_value = ["other_target"] with ( patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), @@ -217,7 +217,7 @@ async def test_start_run_invalid_strategy_raises_value_error(self, mock_memory) mock_sr.get_class.return_value = mock_scenario_class mock_tr = MagicMock() - mock_tr.get_instance_by_name.return_value = MagicMock() + mock_tr.instances.get.return_value = MagicMock() with ( patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), @@ -238,7 +238,7 @@ async def test_start_run_scenario_not_no_arg_instantiable_raises(self, mock_memo mock_sr.get_class.return_value = mock_scenario_class mock_tr = MagicMock() - mock_tr.get_instance_by_name.return_value = MagicMock() + mock_tr.instances.get.return_value = MagicMock() with ( patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), @@ -379,8 +379,8 @@ async def test_start_run_dataset_names_introspection_failure_raises(self, mock_m mock_sr.get_class.return_value = mock_scenario_class mock_tr = MagicMock() - mock_tr.get_instance_by_name.return_value = MagicMock() - mock_tr.get_names.return_value = ["my_target"] + mock_tr.instances.get.return_value = MagicMock() + mock_tr.instances.get_names.return_value = ["my_target"] mock_ir = MagicMock() diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index b86467eb88..f8926b5778 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -13,15 +13,16 @@ from pyrit.backend.models.targets import CreateTargetRequest from pyrit.backend.services.target_service import TargetService, get_target_service from pyrit.models import ComponentIdentifier -from pyrit.registry.object_registries import TargetRegistry +from pyrit.prompt_target import PromptTarget +from pyrit.registry import TargetRegistry @pytest.fixture(autouse=True) def reset_registry(): """Reset the TargetRegistry singleton before each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() yield - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() def _mock_target_identifier(*, class_name: str = "MockTarget", **kwargs) -> ComponentIdentifier: @@ -64,9 +65,9 @@ async def test_list_targets_returns_targets_from_registry(self) -> None: service = TargetService() # Register a mock target - mock_target = MagicMock() + mock_target = MagicMock(spec=PromptTarget) mock_target.get_identifier.return_value = _mock_target_identifier(endpoint="http://test") - service._registry.register_instance(mock_target, name="target-1") + service._registry.instances.register(mock_target, name="target-1") result = await service.list_targets_async() @@ -80,9 +81,9 @@ async def test_list_targets_paginates_with_limit(self) -> None: service = TargetService() for i in range(5): - mock_target = MagicMock() + mock_target = MagicMock(spec=PromptTarget) mock_target.get_identifier.return_value = _mock_target_identifier() - service._registry.register_instance(mock_target, name=f"target-{i}") + service._registry.instances.register(mock_target, name=f"target-{i}") result = await service.list_targets_async(limit=3) @@ -96,9 +97,9 @@ async def test_list_targets_cursor_returns_next_page(self) -> None: service = TargetService() for i in range(5): - mock_target = MagicMock() + mock_target = MagicMock(spec=PromptTarget) mock_target.get_identifier.return_value = _mock_target_identifier() - service._registry.register_instance(mock_target, name=f"target-{i}") + service._registry.instances.register(mock_target, name=f"target-{i}") first_page = await service.list_targets_async(limit=2) second_page = await service.list_targets_async(limit=2, cursor=first_page.pagination.next_cursor) @@ -112,9 +113,9 @@ async def test_list_targets_last_page_has_no_more(self) -> None: service = TargetService() for i in range(3): - mock_target = MagicMock() + mock_target = MagicMock(spec=PromptTarget) mock_target.get_identifier.return_value = _mock_target_identifier() - service._registry.register_instance(mock_target, name=f"target-{i}") + service._registry.instances.register(mock_target, name=f"target-{i}") first_page = await service.list_targets_async(limit=2) last_page = await service.list_targets_async(limit=2, cursor=first_page.pagination.next_cursor) @@ -139,9 +140,9 @@ async def test_get_target_returns_target_from_registry(self) -> None: """Test that get_target returns target built from registry object.""" service = TargetService() - mock_target = MagicMock() + mock_target = MagicMock(spec=PromptTarget) mock_target.get_identifier.return_value = _mock_target_identifier() - service._registry.register_instance(mock_target, name="target-1") + service._registry.instances.register(mock_target, name="target-1") result = await service.get_target_async(target_registry_name="target-1") @@ -153,7 +154,7 @@ async def test_list_targets_includes_extra_params_in_target_specific(self) -> No """Test that extra identifier params (reasoning_effort etc.) appear in target_specific_params.""" service = TargetService() - mock_target = MagicMock() + mock_target = MagicMock(spec=PromptTarget) identifier = ComponentIdentifier( class_name="OpenAIResponseTarget", class_module="pyrit.prompt_target", @@ -167,7 +168,7 @@ async def test_list_targets_includes_extra_params_in_target_specific(self) -> No }, ) mock_target.get_identifier.return_value = identifier - service._registry.register_instance(mock_target, name="response-target") + service._registry.instances.register(mock_target, name="response-target") result = await service.list_targets_async() @@ -183,7 +184,7 @@ async def test_get_target_includes_extra_params_in_target_specific(self) -> None """Test that get_target returns target_specific_params with extra identifier params.""" service = TargetService() - mock_target = MagicMock() + mock_target = MagicMock(spec=PromptTarget) identifier = ComponentIdentifier( class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", @@ -195,7 +196,7 @@ async def test_get_target_includes_extra_params_in_target_specific(self) -> None }, ) mock_target.get_identifier.return_value = identifier - service._registry.register_instance(mock_target, name="chat-target") + service._registry.instances.register(mock_target, name="chat-target") result = await service.get_target_async(target_registry_name="chat-target") @@ -219,8 +220,8 @@ def test_get_target_object_returns_none_for_nonexistent(self) -> None: def test_get_target_object_returns_object_from_registry(self) -> None: """Test that get_target_object returns the actual target object.""" service = TargetService() - mock_target = MagicMock() - service._registry.register_instance(mock_target, name="target-1") + mock_target = MagicMock(spec=PromptTarget) + service._registry.instances.register(mock_target, name="target-1") result = service.get_target_object(target_registry_name="target-1") @@ -610,20 +611,20 @@ async def test_create_round_robin_target_resolves_registry_names(self, sqlite_in # (same class, multi-turn, editable history) that requires real compatible # targets. The service's job is to resolve registry names and pass them # through — the constructor validation is tested in RoundRobinTarget's own tests. - mock_a = MagicMock() + mock_a = MagicMock(spec=PromptTarget) mock_a.get_identifier.return_value = _mock_target_identifier( class_name="OpenAIChatTarget", endpoint="https://a.openai.azure.com", model_name="gpt-4o" ) - mock_b = MagicMock() + mock_b = MagicMock(spec=PromptTarget) mock_b.get_identifier.return_value = _mock_target_identifier( class_name="OpenAIChatTarget", endpoint="https://b.openai.azure.com", model_name="gpt-4o" ) - service._registry.register_instance(mock_a, name="target-a") - service._registry.register_instance(mock_b, name="target-b") + service._registry.instances.register(mock_a, name="target-a") + service._registry.instances.register(mock_b, name="target-b") # Patch RoundRobinTarget so the constructor returns a mock that behaves # like a registered target (has get_identifier, capabilities, etc.) - mock_rr = MagicMock() + mock_rr = MagicMock(spec=PromptTarget) mock_rr.get_identifier.return_value = ComponentIdentifier( class_name="RoundRobinTarget", class_module="pyrit.prompt_target.round_robin_target", @@ -683,21 +684,21 @@ async def test_create_round_robin_target_deduplicates_identical_targets(self, sq identifier_a = _mock_target_identifier( class_name="OpenAIChatTarget", endpoint="https://a.openai.azure.com", model_name="gpt-4o" ) - mock_a = MagicMock() + mock_a = MagicMock(spec=PromptTarget) mock_a.get_identifier.return_value = identifier_a - mock_a_alias = MagicMock() + mock_a_alias = MagicMock(spec=PromptTarget) mock_a_alias.get_identifier.return_value = identifier_a - mock_b = MagicMock() + mock_b = MagicMock(spec=PromptTarget) mock_b.get_identifier.return_value = _mock_target_identifier( class_name="OpenAIChatTarget", endpoint="https://b.openai.azure.com", model_name="gpt-4o" ) - service._registry.register_instance(mock_a, name="target-a") - service._registry.register_instance(mock_a_alias, name="target-a-alias") - service._registry.register_instance(mock_b, name="target-b") + service._registry.instances.register(mock_a, name="target-a") + service._registry.instances.register(mock_a_alias, name="target-a-alias") + service._registry.instances.register(mock_b, name="target-b") - mock_rr = MagicMock() + mock_rr = MagicMock(spec=PromptTarget) mock_rr.get_identifier.return_value = ComponentIdentifier( class_name="RoundRobinTarget", class_module="pyrit.prompt_target.round_robin_target", @@ -729,13 +730,13 @@ async def test_create_round_robin_target_all_duplicates_raises(self, sqlite_inst identifier = _mock_target_identifier( class_name="OpenAIChatTarget", endpoint="https://a.openai.azure.com", model_name="gpt-4o" ) - mock_a = MagicMock() + mock_a = MagicMock(spec=PromptTarget) mock_a.get_identifier.return_value = identifier - mock_a_alias = MagicMock() + mock_a_alias = MagicMock(spec=PromptTarget) mock_a_alias.get_identifier.return_value = identifier - service._registry.register_instance(mock_a, name="target-a") - service._registry.register_instance(mock_a_alias, name="target-a-alias") + service._registry.instances.register(mock_a, name="target-a") + service._registry.instances.register(mock_a_alias, name="target-a-alias") rr_request = CreateTargetRequest( type="RoundRobinTarget", diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py index 61c90b7bc3..43a5cb2d7c 100644 --- a/tests/unit/registry/test_attack_technique_registry.py +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -280,14 +280,14 @@ def _scenario_factories() -> list[AttackTechniqueFactory]: not depend on environment variables or OpenAIChatTarget. """ if not SCENARIO_FACTORIES_FIXTURE: - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") SCENARIO_FACTORIES_FIXTURE.extend(build_scenario_technique_factories()) # This runs at collection time (parametrize). Reset so we don't leak the mock # "adversarial_chat" into the global TargetRegistry singleton of every xdist worker. - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() return SCENARIO_FACTORIES_FIXTURE diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index e36f8f3393..533ee95e1a 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -8,9 +8,6 @@ BaseInstanceRegistry, RegistryEntry, ) -from pyrit.registry.object_registries.retrievable_instance_registry import ( - RetrievableInstanceRegistry, -) class _TestItem(Identifiable): @@ -45,12 +42,29 @@ def _item(value: str) -> _TestItem: return _TestItem(value) -class ConcreteTestRegistry(RetrievableInstanceRegistry["_TestItem"]): - """Concrete implementation of RetrievableInstanceRegistry for testing.""" +class ConcreteTestRegistry(BaseInstanceRegistry["_TestItem"]): + """Concrete instance-holding registry (legacy base) used as a test double. + + Defines the direct-retrieval helpers (``get``/``get_entry``/ + ``get_all_instances``) so the shared ``BaseInstanceRegistry`` infrastructure + can be exercised through a retrievable surface. The canonical retrievable + implementation now lives on ``DefaultInstanceRegistry`` (the ``.instances`` + property); this double mirrors it only for these legacy-base tests. + """ + + def get(self, name: str) -> "_TestItem | None": + entry = self._registry_items.get(name) + return None if entry is None else entry.instance + + def get_entry(self, name: str) -> "RegistryEntry[_TestItem] | None": + return self._registry_items.get(name) + def get_all_instances(self) -> "list[RegistryEntry[_TestItem]]": + return [self._registry_items[name] for name in sorted(self._registry_items.keys())] -class TestRetrievableInstanceRegistrySingleton: - """Tests for the singleton pattern in RetrievableInstanceRegistry.""" + +class TestConcreteInstanceRegistrySingleton: + """Tests for the singleton pattern in the concrete instance registry.""" def setup_method(self): """Reset the singleton before each test.""" @@ -82,8 +96,8 @@ def test_reset_instance_when_not_exists_does_not_raise(self): ConcreteTestRegistry.reset_instance() -class TestRetrievableInstanceRegistryRegistration: - """Tests for registration functionality in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryRegistration: + """Tests for registration functionality in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -134,8 +148,8 @@ def test_register_invalidates_metadata_cache(self): assert len(metadata2) == 2 -class TestRetrievableInstanceRegistryGet: - """Tests for get functionality in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryGet: + """Tests for get functionality in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -158,8 +172,8 @@ def test_get_nonexistent_returns_none(self): assert result is None -class TestRetrievableInstanceRegistryGetEntry: - """Tests for get_entry functionality in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryGetEntry: + """Tests for get_entry functionality in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -186,8 +200,8 @@ def test_get_entry_nonexistent_returns_none(self): assert result is None -class TestRetrievableInstanceRegistryGetNames: - """Tests for get_names functionality in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryGetNames: + """Tests for get_names functionality in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -213,8 +227,8 @@ def test_get_names_returns_sorted_list(self): assert names == ["alpha", "beta", "zeta"] -class TestRetrievableInstanceRegistryGetAllInstances: - """Tests for get_all_instances functionality in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryGetAllInstances: + """Tests for get_all_instances functionality in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -260,8 +274,8 @@ def test_get_all_instances_empty_registry(self): assert result == [] -class TestRetrievableInstanceRegistryListMetadata: - """Tests for list_metadata functionality in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryListMetadata: + """Tests for list_metadata functionality in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -331,8 +345,8 @@ def test_list_metadata_caching(self): assert len(metadata1) == 3 -class TestRetrievableInstanceRegistryTags: - """Tests for tag registration and retrieval in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryTags: + """Tests for tag registration and retrieval in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -429,11 +443,11 @@ def test_get_by_tag_with_list_tags_value_empty_string(self): def test_normalize_tags_none(self): """Test _normalize_tags returns empty dict for None.""" - assert RetrievableInstanceRegistry._normalize_tags(None) == {} + assert BaseInstanceRegistry._normalize_tags(None) == {} def test_normalize_tags_list(self): """Test _normalize_tags converts list to dict with empty values.""" - assert RetrievableInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} + assert BaseInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} def test_normalize_tags_dict(self): """Test _normalize_tags returns a copy of the dict.""" @@ -443,8 +457,8 @@ def test_normalize_tags_dict(self): assert result is not original -class TestRetrievableInstanceRegistryDunderMethods: - """Tests for dunder methods (__contains__, __len__, __iter__) in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryDunderMethods: + """Tests for dunder methods (__contains__, __len__, __iter__) in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -506,18 +520,6 @@ def test_item_registry_has_no_get_all_instances(self): """BaseInstanceRegistry subclasses should not have a get_all_instances() method.""" assert not hasattr(_ItemOnlyRegistry, "get_all_instances") - def test_instance_registry_has_get(self): - """RetrievableInstanceRegistry subclasses should have get().""" - assert hasattr(ConcreteTestRegistry, "get") - - def test_instance_registry_has_get_entry(self): - """RetrievableInstanceRegistry subclasses should have get_entry().""" - assert hasattr(ConcreteTestRegistry, "get_entry") - - def test_instance_registry_has_get_all_instances(self): - """RetrievableInstanceRegistry subclasses should have get_all_instances().""" - assert hasattr(ConcreteTestRegistry, "get_all_instances") - def test_item_registry_shares_common_methods(self): """BaseInstanceRegistry subclasses should have shared registry methods.""" for method in ( @@ -533,8 +535,8 @@ def test_item_registry_shares_common_methods(self): assert hasattr(_ItemOnlyRegistry, method), f"Missing method: {method}" -class TestRetrievableInstanceRegistryAddTags: - """Tests for add_tags functionality in RetrievableInstanceRegistry.""" +class TestConcreteInstanceRegistryAddTags: + """Tests for add_tags functionality in the concrete instance registry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -612,7 +614,7 @@ class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub"]): class TestFindDependentsOfTag: - """Tests for RetrievableInstanceRegistry.find_dependents_of_tag.""" + """Tests for BaseInstanceRegistry.find_dependents_of_tag.""" def setup_method(self) -> None: IdentifierTestRegistry.reset_instance() @@ -720,7 +722,7 @@ def test_tagged_entries_without_eval_hash_returns_empty(self) -> None: assert self.registry.find_dependents_of_tag(tag="refusal") == [] -class TestRetrievableInstanceRegistryMetadataField: +class TestConcreteInstanceRegistryMetadataField: """Tests for the metadata field on RegistryEntry.""" def setup_method(self): diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index 62b6ab0762..0dad25611c 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -30,8 +30,6 @@ from pyrit.registry.components import ( ConverterMetadata, ConverterRegistry, -) -from pyrit.registry.object_registries import ( TargetRegistry, ) from pyrit.registry.resolution import derive_parameters @@ -314,22 +312,22 @@ class TestCreateLLMConverter: def test_build_llm_converter_resolves_target_by_name(self, registry: ConverterRegistry): target = MockPromptTarget() - TargetRegistry.reset_instance() - TargetRegistry.get_registry_singleton().register_instance(target, name="my_target") + TargetRegistry.reset_registry_singleton() + TargetRegistry.get_registry_singleton().instances.register(target, name="my_target") try: converter = registry.create_instance("TenseConverter", converter_target="my_target", tense="past") assert isinstance(converter, TenseConverter) assert converter._converter_target is target finally: - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() def test_build_llm_converter_unknown_target_raises(self, registry: ConverterRegistry): - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() try: with pytest.raises(ValueError, match="not found"): registry.create_instance("TenseConverter", converter_target="missing", tense="past") finally: - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() class TestClassMetadata: diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 177177d6f2..e5fa51105c 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -4,14 +4,16 @@ """ Unit tests for the standalone ``Registry`` base. -``ConverterRegistry`` overrides ``_get_registry_name`` and ``_identifier_type``, so +``ConverterRegistry`` overrides ``_identifier_type`` and supplies discovery hooks, so exercising the base only through it leaves the base's own defaults uncovered: -snake_case naming, the no-identifier path, eager vs. lazy discovery, the metadata +class-name keying, the no-identifier path, eager vs. lazy discovery, the metadata accessors, and the filter wiring. These tests drive a minimal subclass that keeps every base default. """ from dataclasses import dataclass, field +from types import ModuleType +from unittest.mock import MagicMock import pytest @@ -62,16 +64,16 @@ class _TaggedMetadata(ClassRegistryEntry): tags: tuple[str, ...] = field(kw_only=True, default=()) -def test_get_registry_name_defaults_to_snake_case(): +def test_get_registry_name_defaults_to_class_name(): registry = WidgetRegistry() - assert registry.get_class_names() == ["sample_widget", "undocumented_widget"] + assert registry.get_class_names() == ["SampleWidget", "UndocumentedWidget"] def test_build_metadata_uses_first_paragraph_summary(): registry = WidgetRegistry() - meta = registry.get_registered_class_metadata("sample_widget") + meta = registry.get_registered_class_metadata("SampleWidget") assert meta is not None assert meta.class_description == "A sample widget." @@ -82,7 +84,7 @@ def test_build_metadata_uses_first_paragraph_summary(): def test_build_metadata_empty_description_without_docstring(): registry = WidgetRegistry() - meta = registry.get_registered_class_metadata("undocumented_widget") + meta = registry.get_registered_class_metadata("UndocumentedWidget") assert meta is not None assert meta.class_description == "" @@ -91,7 +93,7 @@ def test_build_metadata_empty_description_without_docstring(): def test_class_attributes_empty_without_identifier_type(): registry = WidgetRegistry() - meta = registry.get_registered_class_metadata("sample_widget") + meta = registry.get_registered_class_metadata("SampleWidget") assert meta is not None assert meta.class_attributes == {} @@ -100,7 +102,7 @@ def test_class_attributes_empty_without_identifier_type(): def test_parameters_have_no_references_without_identifier_type(): registry = WidgetRegistry() - meta = registry.get_registered_class_metadata("sample_widget") + meta = registry.get_registered_class_metadata("SampleWidget") assert meta is not None assert all(p.reference is None for p in meta.parameters) @@ -109,7 +111,7 @@ def test_parameters_have_no_references_without_identifier_type(): def test_create_instance_builds_object(): registry = WidgetRegistry() - widget = registry.create_instance("sample_widget", size=3) + widget = registry.create_instance("SampleWidget", size=3) assert isinstance(widget, SampleWidget) assert widget.size == 3 @@ -141,8 +143,8 @@ def test_get_class_metadata_builds_for_unregistered_class(): meta = registry.get_class_metadata(UnregisteredWidget) assert meta.class_name == "UnregisteredWidget" - assert meta.registry_name == "unregistered_widget" - assert "unregistered_widget" not in registry.get_class_names() + assert meta.registry_name == "UnregisteredWidget" + assert "UnregisteredWidget" not in registry.get_class_names() def test_get_class_unknown_name_raises(): @@ -156,8 +158,8 @@ def test_iter_and_contains_and_len(): registry = WidgetRegistry() assert len(registry) == 2 - assert "sample_widget" in registry - assert list(registry) == ["sample_widget", "undocumented_widget"] + assert "SampleWidget" in registry + assert list(registry) == ["SampleWidget", "UndocumentedWidget"] def test_get_all_metadata_no_filters_returns_all(): @@ -165,23 +167,23 @@ def test_get_all_metadata_no_filters_returns_all(): all_meta = registry.get_all_registered_class_metadata() - assert {m.registry_name for m in all_meta} == {"sample_widget", "undocumented_widget"} + assert {m.registry_name for m in all_meta} == {"SampleWidget", "UndocumentedWidget"} def test_get_all_metadata_include_filter_matches_subset(): registry = WidgetRegistry() - result = registry.get_all_registered_class_metadata(include_filters={"registry_name": "sample_widget"}) + result = registry.get_all_registered_class_metadata(include_filters={"registry_name": "SampleWidget"}) - assert [m.registry_name for m in result] == ["sample_widget"] + assert [m.registry_name for m in result] == ["SampleWidget"] def test_get_all_metadata_exclude_filter_removes_match(): registry = WidgetRegistry() - result = registry.get_all_registered_class_metadata(exclude_filters={"registry_name": "sample_widget"}) + result = registry.get_all_registered_class_metadata(exclude_filters={"registry_name": "SampleWidget"}) - assert [m.registry_name for m in result] == ["undocumented_widget"] + assert [m.registry_name for m in result] == ["UndocumentedWidget"] def test_matches_filters_list_containment(): @@ -210,3 +212,42 @@ def __init__(self) -> None: missing_found, missing_value = _get_metadata_value(HasParams(), "missing") assert missing_found is False assert missing_value is None + + +class _WidgetBase: + """Base for the default-discovery hardening test.""" + + +class _ConcreteWidget(_WidgetBase): + """A concrete widget.""" + + +class _PackageDrivenRegistry(Registry[object, ClassRegistryEntry]): + """Registry that uses the base's default ``_discover`` over a supplied package.""" + + def __init__(self, *, package: ModuleType) -> None: + self._package = package + super().__init__(lazy_discovery=False) + + def _base_type(self) -> type: + return _WidgetBase + + def _discovery_package(self) -> ModuleType: + return self._package + + def _metadata_class(self) -> type[ClassRegistryEntry]: + return ClassRegistryEntry + + +def test_discover_skips_spec_type_mock_exports(): + # A foreign test may patch a discovery-package export with a ``MagicMock(spec=type)`` + # that reports ``isinstance(obj, type) is True`` yet makes ``issubclass`` raise + # ``TypeError``. Default discovery must skip it rather than blow up the whole catalog. + package = ModuleType("_fake_widget_package") + package.__all__ = ["_ConcreteWidget", "_LeakedMock"] + package._ConcreteWidget = _ConcreteWidget + package._LeakedMock = MagicMock(spec=type) + + registry = _PackageDrivenRegistry(package=package) + + assert registry.get_class_names() == ["_ConcreteWidget"] diff --git a/tests/unit/registry/test_resolution.py b/tests/unit/registry/test_resolution.py index 72400d692f..a37e0a3694 100644 --- a/tests/unit/registry/test_resolution.py +++ b/tests/unit/registry/test_resolution.py @@ -12,10 +12,10 @@ from pyrit.common import REQUIRED_VALUE from pyrit.common.apply_defaults import _RequiredValueSentinel from pyrit.models import Message, MessagePiece -from pyrit.models.identifiers import ConverterIdentifier +from pyrit.models.identifiers import ConverterIdentifier, TargetIdentifier from pyrit.models.parameter import ComponentType from pyrit.prompt_target import PromptTarget -from pyrit.registry.object_registries import TargetRegistry +from pyrit.registry.components import TargetRegistry from pyrit.registry.resolution import ( derive_parameters, display_choices, @@ -92,6 +92,13 @@ def __init__(self, *, converter_target: str = "x") -> None: self.converter_target = converter_target +class _NeedsTargets: + """Helper whose constructor takes a list-typed registry reference.""" + + def __init__(self, *, targets: list[PromptTarget]) -> None: + self.targets = targets + + def _resolve(cls: type, raw_args: dict[str, object], *, identifier_type: type | None = None) -> dict[str, object]: """Resolve ``raw_args`` against the derived parameter contract for ``cls``.""" return resolve_constructor_args(cls=cls, raw_args=raw_args, identifier_type=identifier_type) @@ -100,20 +107,20 @@ def _resolve(cls: type, raw_args: dict[str, object], *, identifier_type: type | @pytest.fixture def target_registry(): """Provide a fresh TargetRegistry singleton with one registered target.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() registry = TargetRegistry.get_registry_singleton() - registry.register_instance(MockPromptTarget(), name="my_target") + registry.instances.register(MockPromptTarget(), name="my_target") yield registry - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() @pytest.fixture def empty_target_registry(): """Provide a fresh, empty TargetRegistry singleton.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() registry = TargetRegistry.get_registry_singleton() yield registry - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() class TestDisplayChoices: @@ -161,9 +168,33 @@ def test_resolves_registry_reference_by_name(self, target_registry: TargetRegist resolved = _resolve( _NeedsTarget, {"converter_target": "my_target", "offset": "5"}, identifier_type=ConverterIdentifier ) - assert resolved["converter_target"] is target_registry.get_instance_by_name("my_target") + assert resolved["converter_target"] is target_registry.instances.get("my_target") assert resolved["offset"] == 5 + def test_resolves_list_registry_reference_by_name(self, target_registry: TargetRegistry) -> None: + # A ``list[...]`` reference resolves each element by name (the list-aware path + # used by RoundRobinTarget and composite scorers). + target_registry.instances.register(MockPromptTarget(), name="second_target") + resolved = _resolve( + _NeedsTargets, {"targets": ["my_target", "second_target"]}, identifier_type=TargetIdentifier + ) + assert resolved["targets"] == [ + target_registry.instances.get("my_target"), + target_registry.instances.get("second_target"), + ] + + def test_list_registry_reference_instance_passthrough(self, target_registry: TargetRegistry) -> None: + # Non-string elements (already-built instances) pass through unchanged, + # interleaved with names that are looked up. + instance = MockPromptTarget() + resolved = _resolve(_NeedsTargets, {"targets": ["my_target", instance]}, identifier_type=TargetIdentifier) + assert resolved["targets"][0] is target_registry.instances.get("my_target") + assert resolved["targets"][1] is instance + + def test_list_registry_reference_unknown_name_raises(self, target_registry: TargetRegistry) -> None: + with pytest.raises(ValueError, match="missing"): + _resolve(_NeedsTargets, {"targets": ["my_target", "missing"]}, identifier_type=TargetIdentifier) + def test_registry_reference_instance_passthrough(self, target_registry: TargetRegistry) -> None: instance = MockPromptTarget() resolved = _resolve(_NeedsTarget, {"converter_target": instance}, identifier_type=ConverterIdentifier) diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index 6e458ec3eb..2917b13ebe 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -1,9 +1,31 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +""" +Tests for the merged ``ScorerRegistry`` (buildable catalog + instance container) +and its introspection helpers. +""" + +import pytest from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score -from pyrit.registry.object_registries.scorer_registry import ScorerRegistry +from pyrit.models.parameter import ComponentType +from pyrit.prompt_target import ( + PromptTarget, + TargetCapabilities, + TargetConfiguration, +) +from pyrit.registry.components import ( + ScorerMetadata, + ScorerRegistry, + TargetRegistry, +) +from pyrit.registry.resolution import derive_parameters +from pyrit.score import ( + SelfAskRefusalScorer, + TrueFalseCompositeScorer, + TrueFalseScoreAggregator, +) from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -68,332 +90,329 @@ def validate_return_scores(self, scores: list[Score]): pass -class MockGenericScorer(Scorer): - """Mock generic Scorer (not TrueFalse or FloatScale) for testing.""" +class MockChatTarget(PromptTarget): + """Minimal multi-turn capable target so LLM scorers can be built by name.""" - def __init__(self): - super().__init__(validator=DummyValidator()) + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_editable_history=True, + ) + ) - def _build_identifier(self) -> ComponentIdentifier: - """Build the scorer evaluation identifier for this mock scorer. + def __init__(self, *, model_name: str = "mock_model") -> None: + super().__init__(model_name=model_name) - Returns: - ComponentIdentifier: The identifier for this scorer. - """ - return self._create_identifier() + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return [MessagePiece(role="assistant", original_value="mock response").to_message()] - async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: - return [] + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + pass - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: - return [] - def validate_return_scores(self, scores: list[Score]): - pass +@pytest.fixture +def registry(): + """Provide a fresh ``ScorerRegistry`` singleton, reset around each test.""" + ScorerRegistry.reset_registry_singleton() + instance = ScorerRegistry.get_registry_singleton() + yield instance + ScorerRegistry.reset_registry_singleton() - def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: - return [ - Score( - score_value="false", - score_value_description="Mock fallback", - score_type="true_false", - score_category=None, - score_metadata=None, - score_rationale="Mock fallback", - scorer_class_identifier=self.get_identifier(), - message_piece_id=message.message_pieces[0].id or "test-id", - objective=objective, - ) - ] - def get_scorer_metrics(self): - return None +# --------------------------------------------------------------------------- +# Instance container (reached via the ``instances`` property) +# --------------------------------------------------------------------------- class TestScorerRegistrySingleton: """Tests for the singleton pattern in ScorerRegistry.""" def setup_method(self): - """Reset the singleton before each test.""" - ScorerRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() def teardown_method(self): - """Reset the singleton after each test.""" - ScorerRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() def test_get_registry_singleton_returns_same_instance(self): - """Test that get_registry_singleton returns the same singleton each time.""" - instance1 = ScorerRegistry.get_registry_singleton() - instance2 = ScorerRegistry.get_registry_singleton() - - assert instance1 is instance2 + assert ScorerRegistry.get_registry_singleton() is ScorerRegistry.get_registry_singleton() def test_get_registry_singleton_returns_scorer_registry_type(self): - """Test that get_registry_singleton returns a ScorerRegistry instance.""" - instance = ScorerRegistry.get_registry_singleton() - assert isinstance(instance, ScorerRegistry) + assert isinstance(ScorerRegistry.get_registry_singleton(), ScorerRegistry) - def test_reset_instance_clears_singleton(self): - """Test that reset_instance clears the singleton.""" + def test_reset_registry_singleton_clears_singleton(self): instance1 = ScorerRegistry.get_registry_singleton() - ScorerRegistry.reset_instance() - instance2 = ScorerRegistry.get_registry_singleton() - - assert instance1 is not instance2 + ScorerRegistry.reset_registry_singleton() + assert ScorerRegistry.get_registry_singleton() is not instance1 +@pytest.mark.usefixtures("patch_central_database") class TestScorerRegistryRegisterInstance: - """Tests for register_instance functionality in ScorerRegistry.""" + """Tests for instance registration via the ``instances`` property.""" - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ScorerRegistry.reset_instance() - self.registry = ScorerRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ScorerRegistry.reset_instance() - - def test_register_instance_with_custom_name(self): - """Test registering a scorer with a custom name.""" + def test_register_instance_with_custom_name(self, registry: ScorerRegistry): scorer = MockTrueFalseScorer() - self.registry.register_instance(scorer, name="custom_scorer") + registry.instances.register(scorer, name="custom_scorer") - assert "custom_scorer" in self.registry - assert self.registry.get("custom_scorer") is scorer + assert "custom_scorer" in registry.instances + assert registry.instances.get("custom_scorer") is scorer - def test_register_instance_generates_name_from_class(self): - """Test that register_instance generates a name from class name when not provided.""" + def test_register_instance_generates_name_from_class(self, registry: ScorerRegistry): scorer = MockTrueFalseScorer() - self.registry.register_instance(scorer) + registry.instances.register(scorer) - # Name should be derived from class name with hash suffix - names = self.registry.get_names() + names = registry.instances.get_names() assert len(names) == 1 assert names[0].startswith("MockTrueFalseScorer::") - def test_register_instance_multiple_scorers_unique_names(self): - """Test registering multiple scorers generates unique names.""" - scorer1 = MockTrueFalseScorer() - scorer2 = MockFloatScaleScorer() + def test_register_instance_multiple_scorers_unique_names(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer()) + registry.instances.register(MockFloatScaleScorer()) - self.registry.register_instance(scorer1) - self.registry.register_instance(scorer2) + assert len(registry.instances) == 2 - assert len(self.registry) == 2 + def test_register_instance_duplicate_name_overwrites(self, registry: ScorerRegistry): + first = MockTrueFalseScorer() + second = MockTrueFalseScorer() - def test_register_instance_same_scorer_type_different_hash(self): - """Test that same scorer class can be registered with different identifiers.""" - scorer1 = MockTrueFalseScorer() - scorer2 = MockTrueFalseScorer() + registry.instances.register(first, name="same_name") + registry.instances.register(second, name="same_name") - # Register with explicit names since scorers may have same hash - self.registry.register_instance(scorer1, name="scorer_1") - self.registry.register_instance(scorer2, name="scorer_2") + assert len(registry.instances) == 1 + assert registry.instances.get("same_name") is second - assert len(self.registry) == 2 + def test_register_instance_rejects_non_scorer(self, registry: ScorerRegistry): + class NotAScorer: + pass + with pytest.raises(TypeError, match="Scorer"): + registry.instances.register(NotAScorer()) # type: ignore[arg-type] -class TestScorerRegistryGetInstanceByName: - """Tests for get_instance_by_name functionality in ScorerRegistry.""" + assert len(registry.instances) == 0 - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ScorerRegistry.reset_instance() - self.registry = ScorerRegistry.get_registry_singleton() - self.scorer = MockTrueFalseScorer() - self.registry.register_instance(self.scorer, name="test_scorer") - - def teardown_method(self): - """Reset the singleton after each test.""" - ScorerRegistry.reset_instance() - - def test_get_instance_by_name_returns_scorer(self): - """Test getting a registered scorer by name.""" - result = self.registry.get_instance_by_name("test_scorer") - assert result is self.scorer - def test_get_instance_by_name_nonexistent_returns_none(self): - """Test that getting a non-existent scorer returns None.""" - result = self.registry.get_instance_by_name("nonexistent") - assert result is None +@pytest.mark.usefixtures("patch_central_database") +class TestScorerRegistryGetInstanceByName: + """Tests for instance lookup via ``instances.get``.""" + def test_get_instance_by_name_returns_scorer(self, registry: ScorerRegistry): + scorer = MockTrueFalseScorer() + registry.instances.register(scorer, name="test_scorer") + assert registry.instances.get("test_scorer") is scorer -class TestScorerRegistryBuildMetadata: - """Tests for _build_metadata functionality in ScorerRegistry.""" + def test_get_instance_by_name_nonexistent_returns_none(self, registry: ScorerRegistry): + assert registry.instances.get("nonexistent") is None - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ScorerRegistry.reset_instance() - self.registry = ScorerRegistry.get_registry_singleton() - def teardown_method(self): - """Reset the singleton after each test.""" - ScorerRegistry.reset_instance() +@pytest.mark.usefixtures("patch_central_database") +class TestScorerRegistryInstanceMetadata: + """Tests for instance-level metadata (``instances.list_metadata``).""" - def test_build_metadata_true_false_scorer(self): - """Test that metadata correctly identifies TrueFalseScorer type.""" + def test_instance_metadata_is_component_identifier(self, registry: ScorerRegistry): scorer = MockTrueFalseScorer() - self.registry.register_instance(scorer, name="tf_scorer") + registry.instances.register(scorer, name="tf_scorer") - metadata = self.registry.list_metadata() + metadata = registry.instances.list_metadata() assert len(metadata) == 1 - assert metadata[0].params["scorer_type"] == "true_false" + assert isinstance(metadata[0], ComponentIdentifier) assert metadata[0].class_name == "MockTrueFalseScorer" - # unique_name is auto-computed from class_name, not the registry key - assert "MockTrueFalseScorer::" in metadata[0].unique_name - def test_build_metadata_float_scale_scorer(self): - """Test that metadata correctly identifies FloatScaleScorer type.""" - scorer = MockFloatScaleScorer() - self.registry.register_instance(scorer, name="fs_scorer") + def test_instance_metadata_filter_by_class_name(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer(), name="tf1") + registry.instances.register(MockTrueFalseScorer(), name="tf2") + registry.instances.register(MockFloatScaleScorer(), name="fs1") - metadata = self.registry.list_metadata() - assert len(metadata) == 1 - assert metadata[0].params["scorer_type"] == "float_scale" - assert metadata[0].class_name == "MockFloatScaleScorer" + tf_metadata = registry.instances.list_metadata(include_filters={"class_name": "MockTrueFalseScorer"}) + assert len(tf_metadata) == 2 + assert all(m.class_name == "MockTrueFalseScorer" for m in tf_metadata) - def test_build_metadata_unknown_scorer_type(self): - """Test that non-standard scorers get 'unknown' scorer_type.""" - scorer = MockGenericScorer() - self.registry.register_instance(scorer, name="generic_scorer") - metadata = self.registry.list_metadata() - assert len(metadata) == 1 - assert metadata[0].params["scorer_type"] == "unknown" +@pytest.mark.usefixtures("patch_central_database") +class TestScorerRegistryContainerProtocol: + """Tests for the ``instances`` container protocol surface.""" - def test_build_metadata_is_component_identifier(self): - """Test that metadata is the scorer's ComponentIdentifier.""" - scorer = MockTrueFalseScorer() - self.registry.register_instance(scorer, name="tf_scorer") + def test_contains_and_len_and_iter(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer(), name="test_scorer") + assert "test_scorer" in registry.instances + assert "unknown_scorer" not in registry.instances + assert len(registry.instances) == 1 + assert "test_scorer" in list(registry.instances) - metadata = self.registry.list_metadata() - assert isinstance(metadata[0], ComponentIdentifier) - assert metadata[0] == scorer.get_identifier() + def test_get_names_returns_sorted_list(self, registry: ScorerRegistry): + registry.instances.register(MockFloatScaleScorer(), name="zeta_scorer") + registry.instances.register(MockFloatScaleScorer(), name="alpha_scorer") + assert registry.instances.get_names() == ["alpha_scorer", "zeta_scorer"] + def test_get_all_instances_returns_all(self, registry: ScorerRegistry): + tf = MockTrueFalseScorer() + fs = MockFloatScaleScorer() + registry.instances.register(tf, name="tf") + registry.instances.register(fs, name="fs") -class TestScorerRegistryListMetadataFiltering: - """Tests for list_metadata filtering in ScorerRegistry.""" + entry_map = {e.name: e for e in registry.instances.get_all_instances()} + assert entry_map["tf"].instance is tf + assert entry_map["fs"].instance is fs - def setup_method(self): - """Reset and get a fresh registry with multiple scorers.""" - ScorerRegistry.reset_instance() - self.registry = ScorerRegistry.get_registry_singleton() + def test_get_by_tag_returns_tagged_entries(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer(), name="tagged", tags=["best"]) + registry.instances.register(MockTrueFalseScorer(), name="untagged") - self.tf_scorer1 = MockTrueFalseScorer() - self.tf_scorer2 = MockTrueFalseScorer() - self.fs_scorer = MockFloatScaleScorer() + entries = registry.instances.get_by_tag(tag="best") + assert [e.name for e in entries] == ["tagged"] - self.registry.register_instance(self.tf_scorer1, name="tf_scorer_1") - self.registry.register_instance(self.tf_scorer2, name="tf_scorer_2") - self.registry.register_instance(self.fs_scorer, name="fs_scorer") - def teardown_method(self): - """Reset the singleton after each test.""" - ScorerRegistry.reset_instance() +# --------------------------------------------------------------------------- +# Buildable class catalog (discovery + introspection + build) +# --------------------------------------------------------------------------- - def test_list_metadata_filter_by_scorer_type(self): - """Test filtering metadata by scorer_type.""" - tf_metadata = self.registry.list_metadata(include_filters={"scorer_type": "true_false"}) - assert len(tf_metadata) == 2 - assert all(m.params["scorer_type"] == "true_false" for m in tf_metadata) - - fs_metadata = self.registry.list_metadata(include_filters={"scorer_type": "float_scale"}) - assert len(fs_metadata) == 1 - assert fs_metadata[0].params["scorer_type"] == "float_scale" - - def test_list_metadata_filter_by_class_name(self): - """Test filtering metadata by class_name.""" - metadata = self.registry.list_metadata(include_filters={"class_name": "MockTrueFalseScorer"}) - assert len(metadata) == 2 - assert all(m.class_name == "MockTrueFalseScorer" for m in metadata) - - def test_list_metadata_no_filter_returns_all(self): - """Test that list_metadata without filters returns all items.""" - metadata = self.registry.list_metadata() - assert len(metadata) == 3 - - def test_list_metadata_exclude_by_scorer_type(self): - """Test excluding metadata by scorer_type.""" - metadata = self.registry.list_metadata(exclude_filters={"scorer_type": "true_false"}) - assert len(metadata) == 1 - assert metadata[0].params["scorer_type"] == "float_scale" - - def test_list_metadata_combined_include_and_exclude(self): - """Test combined include and exclude filters.""" - # Filter to include true_false scorers, exclude float_scale - # This tests that both filters work together - metadata = self.registry.list_metadata( - include_filters={"scorer_type": "true_false"}, - exclude_filters={"scorer_type": "float_scale"}, - ) - # Should return both true_false scorers (exclude filter doesn't match any of them) - assert len(metadata) == 2 - assert all(m.params["scorer_type"] == "true_false" for m in metadata) - - # Test excluding by class_name - metadata = self.registry.list_metadata( - include_filters={"scorer_type": "true_false"}, - exclude_filters={"class_name": "MockTrueFalseScorer"}, - ) - # Should return 0 since all true_false scorers are MockTrueFalseScorer - assert len(metadata) == 0 +class TestDiscovery: + """Tests for scorer class discovery.""" -class TestScorerRegistryInheritedMethods: - """Tests for inherited methods from RetrievableInstanceRegistry.""" + def test_discovers_known_scorers(self, registry: ScorerRegistry): + names = registry.get_class_names() + assert "SelfAskRefusalScorer" in names + assert "TrueFalseCompositeScorer" in names - def setup_method(self): - """Reset and get a fresh registry.""" - ScorerRegistry.reset_instance() - self.registry = ScorerRegistry.get_registry_singleton() - self.scorer = MockTrueFalseScorer() - self.registry.register_instance(self.scorer, name="test_scorer") + def test_does_not_register_base_class(self, registry: ScorerRegistry): + assert "Scorer" not in registry.get_class_names() + assert "TrueFalseScorer" not in registry.get_class_names() - def teardown_method(self): - """Reset the singleton after each test.""" - ScorerRegistry.reset_instance() + def test_keyed_by_exact_class_name(self, registry: ScorerRegistry): + names = registry.get_class_names() + assert "SelfAskRefusalScorer" in names + assert "self_ask_refusal_scorer" not in names + + +class TestGetClass: + """Tests for get_class (the inherited class-catalog accessor).""" + + def test_returns_class(self, registry: ScorerRegistry): + assert registry.get_class("SelfAskRefusalScorer") is SelfAskRefusalScorer - def test_contains_registered_name(self): - """Test __contains__ for registered name.""" - assert "test_scorer" in self.registry + def test_unknown_type_raises(self, registry: ScorerRegistry): + with pytest.raises(KeyError, match="not found"): + registry.get_class("NotARealScorer") - def test_contains_unregistered_name(self): - """Test __contains__ for unregistered name.""" - assert "unknown_scorer" not in self.registry + def test_is_subclass_relationship(self, registry: ScorerRegistry): + assert issubclass(registry.get_class("SelfAskRefusalScorer"), Scorer) - def test_len_returns_count(self): - """Test __len__ returns correct count.""" - assert len(self.registry) == 1 - def test_iter_yields_names(self): - """Test __iter__ yields registered names.""" - names = list(self.registry) - assert "test_scorer" in names +@pytest.mark.usefixtures("patch_central_database") +class TestCreateLLMScorer: + """Tests that LLM scorers are buildable by resolving a target by name.""" - def test_get_names_returns_sorted_list(self): - """Test get_names returns sorted list of names.""" - self.registry.register_instance(MockFloatScaleScorer(), name="alpha_scorer") - self.registry.register_instance(MockFloatScaleScorer(), name="zeta_scorer") + def test_build_llm_scorer_resolves_chat_target_by_name(self, registry: ScorerRegistry): + target = MockChatTarget() + TargetRegistry.reset_registry_singleton() + TargetRegistry.get_registry_singleton().instances.register(target, name="scorer_target") + try: + scorer = registry.create_instance("SelfAskRefusalScorer", chat_target="scorer_target") + assert isinstance(scorer, SelfAskRefusalScorer) + assert scorer.get_chat_target() is target + finally: + TargetRegistry.reset_registry_singleton() - names = self.registry.get_names() - assert names == ["alpha_scorer", "test_scorer", "zeta_scorer"] + def test_build_llm_scorer_unknown_target_raises(self, registry: ScorerRegistry): + TargetRegistry.reset_registry_singleton() + try: + with pytest.raises(ValueError, match="not found"): + registry.create_instance("SelfAskRefusalScorer", chat_target="missing") + finally: + TargetRegistry.reset_registry_singleton() + def test_build_llm_scorer_list_for_scalar_reference_raises(self, registry: ScorerRegistry): + # A scalar reference (``chat_target``) must reject a list value. + with pytest.raises(ValueError, match="expected a single"): + registry.create_instance("SelfAskRefusalScorer", chat_target=["a", "b"]) -class TestComponentIdentifierInRegistry: - """Tests for ComponentIdentifier usage in scorer registry.""" - def test_component_identifier_has_scorer_type_in_params(self): - """Test that ComponentIdentifier includes scorer_type in params.""" - identifier = ComponentIdentifier( - class_name="TestScorer", - class_module="test.module", - params={"scorer_type": "true_false"}, +@pytest.mark.usefixtures("patch_central_database") +class TestCreateCompositeScorer: + """Tests the list-aware SCORER reference path (composite from a list of names).""" + + def test_build_composite_resolves_sub_scorers_by_name(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer(), name="s1") + registry.instances.register(MockTrueFalseScorer(), name="s2") + + composite = registry.create_instance( + "TrueFalseCompositeScorer", + scorers=["s1", "s2"], + aggregator=TrueFalseScoreAggregator.OR, + ) + assert isinstance(composite, TrueFalseCompositeScorer) + + def test_build_composite_unknown_sub_scorer_raises(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer(), name="s1") + with pytest.raises(ValueError, match="not found"): + registry.create_instance( + "TrueFalseCompositeScorer", + scorers=["s1", "missing"], + aggregator=TrueFalseScoreAggregator.OR, + ) + + def test_build_composite_resolves_prebuilt_scorers_in_list(self, registry: ScorerRegistry): + # Passthrough path inside a list: already-built scorers pass through unchanged. + composite = registry.create_instance( + "TrueFalseCompositeScorer", + scorers=[MockTrueFalseScorer(), MockTrueFalseScorer()], + aggregator=TrueFalseScoreAggregator.OR, ) + assert isinstance(composite, TrueFalseCompositeScorer) + + def test_build_composite_mixes_names_and_instances(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer(), name="s1") + composite = registry.create_instance( + "TrueFalseCompositeScorer", + scorers=["s1", MockTrueFalseScorer()], + aggregator=TrueFalseScoreAggregator.OR, + ) + assert isinstance(composite, TrueFalseCompositeScorer) + + def test_build_composite_scalar_for_list_reference_raises(self, registry: ScorerRegistry): + registry.instances.register(MockTrueFalseScorer(), name="s1") + with pytest.raises(ValueError, match="expected a list"): + registry.create_instance( + "TrueFalseCompositeScorer", + scorers="s1", + aggregator=TrueFalseScoreAggregator.OR, + ) + + +class TestClassMetadata: + """Tests for scorer class-catalog metadata building.""" + + def _metadata_for(self, registry: ScorerRegistry, name: str) -> ScorerMetadata: + return next(m for m in registry.get_all_registered_class_metadata() if m.class_name == name) + + def test_metadata_is_scorer_metadata(self, registry: ScorerRegistry): + meta = self._metadata_for(registry, "SelfAskRefusalScorer") + assert isinstance(meta, ScorerMetadata) + assert meta.class_name == "SelfAskRefusalScorer" + + def test_is_llm_based_flag(self, registry: ScorerRegistry): + # An LLM scorer takes a ``chat_target`` (TARGET reference); a composite does not. + assert self._metadata_for(registry, "SelfAskRefusalScorer").is_llm_based is True + assert self._metadata_for(registry, "TrueFalseCompositeScorer").is_llm_based is False + + def test_composite_scorers_param_is_reference(self, registry: ScorerRegistry): + meta = self._metadata_for(registry, "TrueFalseCompositeScorer") + assert any(p.is_reference_to(ComponentType.SCORER) for p in meta.parameters) + + +class TestRegistrationGate: + """The identifier blueprint must line up with a resolvable contract for every scorer.""" + + def test_discovery_validates_all_scorers(self, registry: ScorerRegistry) -> None: + names = registry.get_class_names() + assert names + assert "SelfAskRefusalScorer" in names + + def test_is_llm_based_matches_target_reference(self, registry: ScorerRegistry) -> None: + from pyrit.models.identifiers import ScorerIdentifier - assert identifier.class_name == "TestScorer" - assert identifier.class_module == "test.module" - assert identifier.params["scorer_type"] == "true_false" - # unique_name is auto-computed - assert identifier.unique_name is not None - assert identifier.hash is not None + for meta in registry.get_all_registered_class_metadata(): + parameters = derive_parameters(cls=registry.get_class(meta.class_name), identifier_type=ScorerIdentifier) + has_target = any(p.is_reference_to(ComponentType.TARGET) for p in parameters) + assert meta.is_llm_based is has_target, f"is_llm_based mismatch for {meta.class_name}" diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 155313b285..575c17dea1 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -1,250 +1,324 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +""" +Tests for the merged ``TargetRegistry`` (buildable catalog + instance container) +and its introspection helpers. +""" import pytest from pyrit.models import ComponentIdentifier, Message, MessagePiece -from pyrit.prompt_target import PromptTarget -from pyrit.registry.object_registries.target_registry import TargetRegistry +from pyrit.models.parameter import ComponentType +from pyrit.prompt_target import ( + PromptTarget, + RoundRobinTarget, + TargetCapabilities, + TargetConfiguration, +) +from pyrit.registry.components import TargetMetadata, TargetRegistry +from pyrit.registry.resolution import derive_parameters class MockPromptTarget(PromptTarget): - """Mock PromptTarget for testing.""" - - def __init__(self, *, model_name: str = "mock_model") -> None: - super().__init__(model_name=model_name) - - async def _send_prompt_to_target_async( - self, - *, - normalized_conversation: list[Message], - ) -> list[Message]: - return [ - MessagePiece( - role="assistant", - original_value="mock response", - ).to_message() - ] + """Minimal PromptTarget (multi-turn capable) for registry tests.""" + + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_editable_history=True, + ) + ) + + def __init__(self, *, model_name: str = "mock_model", endpoint: str | None = None) -> None: + super().__init__(model_name=model_name, endpoint=endpoint) + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return [MessagePiece(role="assistant", original_value="mock response").to_message()] def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass class MockPromptChatTarget(PromptTarget): - """Mock chat-style target for testing conversation history support.""" + """A second mock target for multi-instance tests.""" def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http://chat-test") -> None: super().__init__(model_name=model_name, endpoint=endpoint) - async def _send_prompt_to_target_async( - self, - *, - normalized_conversation: list[Message], - ) -> list[Message]: - return [ - MessagePiece( - role="assistant", - original_value="chat response", - ).to_message() - ] + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return [MessagePiece(role="assistant", original_value="chat response").to_message()] def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass +@pytest.fixture +def registry(): + """Provide a fresh ``TargetRegistry`` singleton, reset around each test.""" + TargetRegistry.reset_registry_singleton() + instance = TargetRegistry.get_registry_singleton() + yield instance + TargetRegistry.reset_registry_singleton() + + +# --------------------------------------------------------------------------- +# Instance container (reached via the ``instances`` property) +# --------------------------------------------------------------------------- + + class TestTargetRegistrySingleton: """Tests for the singleton pattern in TargetRegistry.""" def setup_method(self): - """Reset the singleton before each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() def teardown_method(self): - """Reset the singleton after each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() def test_get_registry_singleton_returns_same_instance(self): - """Test that get_registry_singleton returns the same singleton each time.""" - instance1 = TargetRegistry.get_registry_singleton() - instance2 = TargetRegistry.get_registry_singleton() - - assert instance1 is instance2 + assert TargetRegistry.get_registry_singleton() is TargetRegistry.get_registry_singleton() def test_get_registry_singleton_returns_target_registry_type(self): - """Test that get_registry_singleton returns a TargetRegistry instance.""" - instance = TargetRegistry.get_registry_singleton() - assert isinstance(instance, TargetRegistry) + assert isinstance(TargetRegistry.get_registry_singleton(), TargetRegistry) - def test_reset_instance_clears_singleton(self): - """Test that reset_instance clears the singleton.""" + def test_reset_registry_singleton_clears_singleton(self): instance1 = TargetRegistry.get_registry_singleton() - TargetRegistry.reset_instance() - instance2 = TargetRegistry.get_registry_singleton() - - assert instance1 is not instance2 + TargetRegistry.reset_registry_singleton() + assert TargetRegistry.get_registry_singleton() is not instance1 @pytest.mark.usefixtures("patch_central_database") class TestTargetRegistryRegisterInstance: - """Tests for register_instance functionality in TargetRegistry.""" + """Tests for instance registration via the ``instances`` property.""" - def setup_method(self): - """Reset and get a fresh registry for each test.""" - TargetRegistry.reset_instance() - self.registry = TargetRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - TargetRegistry.reset_instance() - - def test_register_instance_with_custom_name(self): - """Test registering a target with a custom name.""" + def test_register_instance_with_custom_name(self, registry: TargetRegistry): target = MockPromptTarget() - self.registry.register_instance(target, name="custom_target") + registry.instances.register(target, name="custom_target") - assert "custom_target" in self.registry - assert self.registry.get("custom_target") is target + assert "custom_target" in registry.instances + assert registry.instances.get("custom_target") is target - def test_register_instance_generates_name_from_class(self): - """Test that register_instance generates a name from class name when not provided.""" + def test_register_instance_generates_name_from_class(self, registry: TargetRegistry): target = MockPromptTarget() - self.registry.register_instance(target) + registry.instances.register(target) - # Name should be derived from class name with hash suffix - names = self.registry.get_names() + names = registry.instances.get_names() assert len(names) == 1 assert names[0].startswith("MockPromptTarget::") - def test_register_instance_multiple_targets_unique_names(self): - """Test registering multiple targets generates unique names.""" - target1 = MockPromptTarget() - target2 = MockPromptChatTarget() + def test_register_instance_multiple_targets_unique_names(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget()) + registry.instances.register(MockPromptChatTarget()) - self.registry.register_instance(target1) - self.registry.register_instance(target2) + assert len(registry.instances) == 2 - assert len(self.registry) == 2 + def test_register_instance_duplicate_name_overwrites(self, registry: TargetRegistry): + first = MockPromptTarget(model_name="first") + second = MockPromptTarget(model_name="second") - def test_register_instance_same_target_type_different_config(self): - """Test that same target class with different configs can be registered.""" - target1 = MockPromptTarget(model_name="model_a") - target2 = MockPromptTarget(model_name="model_b") + registry.instances.register(first, name="same_name") + registry.instances.register(second, name="same_name") - # Register with explicit names - self.registry.register_instance(target1, name="target_1") - self.registry.register_instance(target2, name="target_2") + assert len(registry.instances) == 1 + assert registry.instances.get("same_name") is second - assert len(self.registry) == 2 + def test_register_instance_rejects_non_target(self, registry: TargetRegistry): + class NotATarget: + pass - def test_register_instance_with_duplicate_name_silently_overwrites(self): - """Characterization: re-registering an existing name silently replaces the prior entry. + with pytest.raises(TypeError, match="PromptTarget"): + registry.instances.register(NotATarget()) # type: ignore[arg-type] - BaseInstanceRegistry.register is plain dict assignment; there is no - collision check, warning, or error. This test pins the current behavior - so any future tightening (warn, raise, idempotent skip) is an - intentional decision rather than a silent regression. Tracked as - ``duplicate-registry-name`` in failure_mode_followups for the PR - review batch. - """ - first = MockPromptTarget(model_name="first") - second = MockPromptTarget(model_name="second") + assert len(registry.instances) == 0 - self.registry.register_instance(first, name="same_name") - self.registry.register_instance(second, name="same_name") - assert len(self.registry) == 1 - assert self.registry.get("same_name") is second +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryGetInstanceByName: + """Tests for instance lookup via ``instances.get``.""" + + def test_get_instance_by_name_returns_target(self, registry: TargetRegistry): + target = MockPromptTarget() + registry.instances.register(target, name="test_target") + assert registry.instances.get("test_target") is target + + def test_get_instance_by_name_nonexistent_returns_none(self, registry: TargetRegistry): + assert registry.instances.get("nonexistent") is None @pytest.mark.usefixtures("patch_central_database") -class TestTargetRegistryGetInstanceByName: - """Tests for get_instance_by_name functionality in TargetRegistry.""" +class TestTargetRegistryInstanceMetadata: + """Tests for instance-level metadata (``instances.list_metadata``).""" - def setup_method(self): - """Reset and get a fresh registry for each test.""" - TargetRegistry.reset_instance() - self.registry = TargetRegistry.get_registry_singleton() - self.target = MockPromptTarget() - self.registry.register_instance(self.target, name="test_target") + def test_instance_metadata_is_component_identifier(self, registry: TargetRegistry): + target = MockPromptTarget(model_name="test_model") + registry.instances.register(target, name="mock_target") - def teardown_method(self): - """Reset the singleton after each test.""" - TargetRegistry.reset_instance() + metadata = registry.instances.list_metadata() + assert len(metadata) == 1 + assert isinstance(metadata[0], ComponentIdentifier) + assert metadata[0].class_name == "MockPromptTarget" + assert metadata[0].params["model_name"] == "test_model" - def test_get_instance_by_name_returns_target(self): - """Test getting a registered target by name.""" - result = self.registry.get_instance_by_name("test_target") - assert result is self.target + def test_instance_metadata_filter_by_class_name(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget(model_name="a"), name="t1") + registry.instances.register(MockPromptTarget(model_name="b"), name="t2") + registry.instances.register(MockPromptChatTarget(), name="chat") - def test_get_instance_by_name_nonexistent_returns_none(self): - """Test that getting a non-existent target returns None.""" - result = self.registry.get_instance_by_name("nonexistent") - assert result is None + metadata = registry.instances.list_metadata(include_filters={"class_name": "MockPromptTarget"}) + assert len(metadata) == 2 + assert all(m.class_name == "MockPromptTarget" for m in metadata) @pytest.mark.usefixtures("patch_central_database") -class TestTargetRegistryBuildMetadata: - """Tests for _build_metadata functionality in TargetRegistry.""" +class TestTargetRegistryContainerProtocol: + """Tests for the ``instances`` container protocol surface.""" - def setup_method(self): - """Reset and get a fresh registry for each test.""" - TargetRegistry.reset_instance() - self.registry = TargetRegistry.get_registry_singleton() + def test_contains_and_len_and_iter(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget(), name="test_target") + assert "test_target" in registry.instances + assert "unknown_target" not in registry.instances + assert len(registry.instances) == 1 + assert "test_target" in list(registry.instances) - def teardown_method(self): - """Reset the singleton after each test.""" - TargetRegistry.reset_instance() + def test_get_names_returns_sorted_list(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget(), name="zeta_target") + registry.instances.register(MockPromptTarget(), name="alpha_target") + assert registry.instances.get_names() == ["alpha_target", "zeta_target"] - def test_build_metadata_includes_class_name(self): - """Test that metadata (ComponentIdentifier) includes the class name.""" - target = MockPromptTarget() - self.registry.register_instance(target, name="mock_target") + def test_get_all_instances_returns_all(self, registry: TargetRegistry): + a = MockPromptTarget() + b = MockPromptChatTarget() + registry.instances.register(a, name="a") + registry.instances.register(b, name="b") - metadata = self.registry.list_metadata() - assert len(metadata) == 1 - assert isinstance(metadata[0], ComponentIdentifier) - assert metadata[0].class_name == "MockPromptTarget" + entry_map = {e.name: e for e in registry.instances.get_all_instances()} + assert entry_map["a"].instance is a + assert entry_map["b"].instance is b - def test_build_metadata_includes_model_name(self): - """Test that metadata includes the model_name.""" - target = MockPromptTarget(model_name="test_model") - self.registry.register_instance(target, name="mock_target") - metadata = self.registry.list_metadata() - assert metadata[0].params["model_name"] == "test_model" +# --------------------------------------------------------------------------- +# Buildable class catalog (discovery + introspection + build) +# --------------------------------------------------------------------------- -@pytest.mark.usefixtures("patch_central_database") -class TestTargetRegistryListMetadata: - """Tests for list_metadata in TargetRegistry.""" +class TestDiscovery: + """Tests for target class discovery.""" - def setup_method(self): - """Reset and get a fresh registry with multiple targets.""" - TargetRegistry.reset_instance() - self.registry = TargetRegistry.get_registry_singleton() + def test_discovers_known_targets(self, registry: TargetRegistry): + names = registry.get_class_names() + assert "OpenAIChatTarget" in names + assert "RoundRobinTarget" in names - self.target1 = MockPromptTarget(model_name="model_a") - self.target2 = MockPromptTarget(model_name="model_b") - self.chat_target = MockPromptChatTarget() + def test_does_not_register_base_class(self, registry: TargetRegistry): + assert "PromptTarget" not in registry.get_class_names() - self.registry.register_instance(self.target1, name="target_1") - self.registry.register_instance(self.target2, name="target_2") - self.registry.register_instance(self.chat_target, name="chat_target") + def test_keyed_by_exact_class_name(self, registry: TargetRegistry): + names = registry.get_class_names() + assert "OpenAIChatTarget" in names + assert "openai_chat_target" not in names - def teardown_method(self): - """Reset the singleton after each test.""" - TargetRegistry.reset_instance() - def test_list_metadata_returns_all_registered(self): - """Test that list_metadata returns metadata for all registered targets.""" - metadata = self.registry.list_metadata() - assert len(metadata) == 3 +class TestGetClass: + """Tests for get_class (the inherited class-catalog accessor).""" + + def test_returns_class(self, registry: TargetRegistry): + assert registry.get_class("RoundRobinTarget") is RoundRobinTarget + + def test_unknown_type_raises(self, registry: TargetRegistry): + with pytest.raises(KeyError, match="not found"): + registry.get_class("NotARealTarget") - def test_list_metadata_filter_by_class_name(self): - """Test filtering metadata by class_name.""" - mock_metadata = self.registry.list_metadata(include_filters={"class_name": "MockPromptTarget"}) + def test_is_subclass_relationship(self, registry: TargetRegistry): + assert issubclass(registry.get_class("RoundRobinTarget"), PromptTarget) - assert len(mock_metadata) == 2 - for m in mock_metadata: - assert m.class_name == "MockPromptTarget" + +@pytest.mark.usefixtures("patch_central_database") +class TestCreateInstance: + """Tests for create_instance (build via the shared resolver).""" + + def test_build_round_robin_resolves_targets_by_name(self, registry: TargetRegistry): + # The list-aware resolution path: a ``list[str]`` of registry names is + # resolved element by element into the registered target instances. + registry.instances.register(MockPromptTarget(model_name="m", endpoint="http://a"), name="t1") + registry.instances.register(MockPromptTarget(model_name="m", endpoint="http://b"), name="t2") + + rr = registry.create_instance("RoundRobinTarget", targets=["t1", "t2"]) + assert isinstance(rr, RoundRobinTarget) + + def test_build_round_robin_unknown_target_raises(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget(model_name="m", endpoint="http://a"), name="t1") + with pytest.raises(ValueError, match="not found"): + registry.create_instance("RoundRobinTarget", targets=["t1", "missing"]) + + def test_build_round_robin_resolves_prebuilt_instances_in_list(self, registry: TargetRegistry): + # Passthrough path inside a list: already-built instances are passed through + # unchanged rather than looked up by name. + t1 = MockPromptTarget(model_name="m", endpoint="http://a") + t2 = MockPromptTarget(model_name="m", endpoint="http://b") + rr = registry.create_instance("RoundRobinTarget", targets=[t1, t2]) + assert isinstance(rr, RoundRobinTarget) + + def test_build_round_robin_mixes_names_and_instances(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget(model_name="m", endpoint="http://a"), name="t1") + t2 = MockPromptTarget(model_name="m", endpoint="http://b") + rr = registry.create_instance("RoundRobinTarget", targets=["t1", t2]) + assert isinstance(rr, RoundRobinTarget) + + def test_build_round_robin_scalar_for_list_reference_raises(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget(model_name="m", endpoint="http://a"), name="t1") + with pytest.raises(ValueError, match="expected a list"): + registry.create_instance("RoundRobinTarget", targets="t1") + + def test_unknown_type_raises(self, registry: TargetRegistry): + with pytest.raises(KeyError, match="not found"): + registry.create_instance("NotARealTarget") + + def test_build_does_not_register_instance(self, registry: TargetRegistry): + registry.instances.register(MockPromptTarget(model_name="m", endpoint="http://a"), name="t1") + registry.instances.register(MockPromptTarget(model_name="m", endpoint="http://b"), name="t2") + registry.create_instance("RoundRobinTarget", targets=["t1", "t2"]) + # The two pre-registered targets remain; the built RR is not auto-registered. + assert len(registry.instances) == 2 + + +class TestClassMetadata: + """Tests for target class-catalog metadata building.""" + + def _metadata_for(self, registry: TargetRegistry, name: str) -> TargetMetadata: + return next(m for m in registry.get_all_registered_class_metadata() if m.class_name == name) + + def test_metadata_is_target_metadata(self, registry: TargetRegistry): + meta = self._metadata_for(registry, "RoundRobinTarget") + assert isinstance(meta, TargetMetadata) + assert meta.class_name == "RoundRobinTarget" + + def test_round_robin_targets_param_is_reference(self, registry: TargetRegistry): + meta = self._metadata_for(registry, "RoundRobinTarget") + assert any(p.is_reference_to(ComponentType.TARGET) for p in meta.parameters) + + +class TestRegistrationGate: + """The identifier blueprint must line up with a resolvable contract for every target.""" + + def test_discovery_validates_all_targets(self, registry: TargetRegistry) -> None: + # Discovery registers every target through ``register_class``, which validates + # each class; accessing the catalog therefore proves every target is buildable. + names = registry.get_class_names() + assert names + assert "RoundRobinTarget" in names + + def test_every_target_reference_maps_to_a_wired_registry(self, registry: TargetRegistry) -> None: + from pyrit.models.identifiers import TargetIdentifier + + for name in registry.get_class_names(): + parameters = derive_parameters(cls=registry.get_class(name), identifier_type=TargetIdentifier) + for param in parameters: + if param.reference is not None: + assert param.reference.component_type in ( + ComponentType.TARGET, + ComponentType.CONVERTER, + ComponentType.SCORER, + ) diff --git a/tests/unit/scenario/airt/test_cyber.py b/tests/unit/scenario/airt/test_cyber.py index de86aa365e..40e4599207 100644 --- a/tests/unit/scenario/airt/test_cyber.py +++ b/tests/unit/scenario/airt/test_cyber.py @@ -73,19 +73,19 @@ def reset_technique_registry(): from pyrit.scenario.scenarios.airt.cyber import _build_cyber_strategy AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_cyber_strategy.cache_clear() adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True target_registry = TargetRegistry.get_registry_singleton() - target_registry.register_instance(adv_target, name="adversarial_chat") + target_registry.instances.register(adv_target, name="adversarial_chat") technique_registry = AttackTechniqueRegistry.get_registry_singleton() technique_registry.register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_cyber_strategy.cache_clear() diff --git a/tests/unit/scenario/airt/test_leakage.py b/tests/unit/scenario/airt/test_leakage.py index 77a791e41b..42f70f642e 100644 --- a/tests/unit/scenario/airt/test_leakage.py +++ b/tests/unit/scenario/airt/test_leakage.py @@ -90,18 +90,18 @@ def mock_objective_scorer(): def reset_technique_registry(): """Reset registries and populate scenario factories for each test.""" AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_leakage_strategy.cache_clear() adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") technique_registry = AttackTechniqueRegistry.get_registry_singleton() technique_registry.register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_leakage_strategy.cache_clear() diff --git a/tests/unit/scenario/airt/test_rapid_response.py b/tests/unit/scenario/airt/test_rapid_response.py index 42a1059138..82b5387fec 100644 --- a/tests/unit/scenario/airt/test_rapid_response.py +++ b/tests/unit/scenario/airt/test_rapid_response.py @@ -89,18 +89,18 @@ def reset_technique_registry(): from pyrit.scenario.scenarios.airt.rapid_response import _build_rapid_response_strategy AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_rapid_response_strategy.cache_clear() adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") technique_registry = AttackTechniqueRegistry.get_registry_singleton() technique_registry.register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_rapid_response_strategy.cache_clear() diff --git a/tests/unit/scenario/benchmark/test_adversarial.py b/tests/unit/scenario/benchmark/test_adversarial.py index bebe095b11..a0299df1cd 100644 --- a/tests/unit/scenario/benchmark/test_adversarial.py +++ b/tests/unit/scenario/benchmark/test_adversarial.py @@ -71,14 +71,14 @@ def _build_benchmarkable_factories_snapshot() -> list: factory construction does not depend on environment variables, then filters by the same predicate used in ``AdversarialBenchmark._get_benchmarkable_factories``. """ - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() adv = MagicMock(spec=PromptTarget) adv.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv, name="adversarial_chat") try: factories = build_scenario_technique_factories() finally: - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() return [f for f in factories if f.uses_adversarial and "core" in f.strategy_tags] @@ -102,17 +102,17 @@ def reset_technique_registry(): because our implementation uses ``@cache`` (not ``_cached_strategy_class``). """ AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_benchmark_strategy.cache_clear() adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") AttackTechniqueRegistry.get_registry_singleton().register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_benchmark_strategy.cache_clear() @@ -120,7 +120,7 @@ def _register_adversarial_target(*, name: str) -> PromptTarget: """Register a mock adversarial target in TargetRegistry.""" target = MagicMock(spec=PromptTarget) registry = TargetRegistry.get_registry_singleton() - registry.register_instance(target, name=name) + registry.instances.register(target, name=name) return target @@ -463,7 +463,7 @@ async def test_display_group_uses_registry_name_not_target_model_name(self): target._underlying_model = "another-model-identity" target._endpoint = "https://hijacked.example.com/openai/v1" target.name = "name-attribute-that-must-not-leak" - TargetRegistry.get_registry_singleton().register_instance(target, name="adv_a") + TargetRegistry.get_registry_singleton().instances.register(target, name="adv_a") # Reset the technique registry to get a controllable mock factory AttackTechniqueRegistry.reset_instance() _build_benchmark_strategy.cache_clear() @@ -499,8 +499,8 @@ async def test_factory_create_called_per_target_with_adversarial_chat(self): # 1 factory × 2 targets × 1 dataset = 2 create calls assert factory.create.call_count == 2 - target_a = TargetRegistry.get_registry_singleton().get_instance_by_name("adv_a") - target_b = TargetRegistry.get_registry_singleton().get_instance_by_name("adv_b") + target_a = TargetRegistry.get_registry_singleton().instances.get("adv_a") + target_b = TargetRegistry.get_registry_singleton().instances.get("adv_b") injected_targets = {call.kwargs["adversarial_chat"] for call in factory.create.call_args_list} assert injected_targets == {target_a, target_b} diff --git a/tests/unit/scenario/core/test_baseline_deprecation.py b/tests/unit/scenario/core/test_baseline_deprecation.py index f23da82177..2e63e50bf3 100644 --- a/tests/unit/scenario/core/test_baseline_deprecation.py +++ b/tests/unit/scenario/core/test_baseline_deprecation.py @@ -124,17 +124,17 @@ def _populate_registry(self): from pyrit.setup.initializers.components.scenario_techniques import build_scenario_technique_factories AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() Cyber._cached_strategy_class = None adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") AttackTechniqueRegistry.get_registry_singleton().register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() Cyber._cached_strategy_class = None @pytest.mark.parametrize( diff --git a/tests/unit/scenario/core/test_scenario.py b/tests/unit/scenario/core/test_scenario.py index c5e886944f..e717c96ad5 100644 --- a/tests/unit/scenario/core/test_scenario.py +++ b/tests/unit/scenario/core/test_scenario.py @@ -875,7 +875,7 @@ def test_returns_registry_scorer_when_tagged(self, mock_registry_cls) -> None: mock_entry.instance = mock_scorer mock_registry = MagicMock() - mock_registry.get_by_tag.return_value = [mock_entry] + mock_registry.instances.get_by_tag.return_value = [mock_entry] mock_registry_cls.get_registry_singleton.return_value = mock_registry # Mock self with _get_additional_scoring_questions returning empty sequence @@ -892,7 +892,7 @@ def test_returns_fallback_when_registry_empty(self, mock_registry_cls, mock_get_ from pyrit.score import TrueFalseInverterScorer mock_registry = MagicMock() - mock_registry.get_by_tag.return_value = [] + mock_registry.instances.get_by_tag.return_value = [] mock_registry_cls.get_registry_singleton.return_value = mock_registry # Mock self with _get_additional_scoring_questions returning empty sequence diff --git a/tests/unit/scenario/core/test_scenario_strategy_invariants.py b/tests/unit/scenario/core/test_scenario_strategy_invariants.py index 5b363f4315..67b942ebdb 100644 --- a/tests/unit/scenario/core/test_scenario_strategy_invariants.py +++ b/tests/unit/scenario/core/test_scenario_strategy_invariants.py @@ -39,17 +39,17 @@ def _reset_registries(): from pyrit.setup.initializers.components.scenario_techniques import build_scenario_technique_factories AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() Cyber._cached_strategy_class = None RapidResponse._cached_strategy_class = None adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") AttackTechniqueRegistry.get_registry_singleton().register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() Cyber._cached_strategy_class = None RapidResponse._cached_strategy_class = None diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index 848851ae0a..5fc2e6f89a 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -50,11 +50,11 @@ def reset_technique_registry(): from pyrit.registry import TargetRegistry AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() TextAdaptive._cached_strategy_class = None yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() TextAdaptive._cached_strategy_class = None diff --git a/tests/unit/scenario/test_package_lazy_attrs.py b/tests/unit/scenario/test_package_lazy_attrs.py index bcc90e946b..7af03f0c8e 100644 --- a/tests/unit/scenario/test_package_lazy_attrs.py +++ b/tests/unit/scenario/test_package_lazy_attrs.py @@ -22,7 +22,7 @@ def populate_registries(): """Populate the technique + target registries so lazy strategy builders succeed.""" AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_cyber_strategy.cache_clear() _build_leakage_strategy.cache_clear() _build_rapid_response_strategy.cache_clear() @@ -30,12 +30,12 @@ def populate_registries(): adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") AttackTechniqueRegistry.get_registry_singleton().register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() _build_cyber_strategy.cache_clear() _build_leakage_strategy.cache_clear() _build_rapid_response_strategy.cache_clear() diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index f3ecd53d5f..5867ffde64 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -24,16 +24,16 @@ def populated_technique_registry(): """Populate the technique + target registries so scenario metadata building succeeds.""" AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() adv_target = MagicMock(spec=PromptTarget) adv_target.capabilities.includes.return_value = True - TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + TargetRegistry.get_registry_singleton().instances.register(adv_target, name="adversarial_chat") AttackTechniqueRegistry.get_registry_singleton().register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() @dataclass diff --git a/tests/unit/setup/test_scenario_techniques_initializer.py b/tests/unit/setup/test_scenario_techniques_initializer.py index 8878df8cf0..d3bebd15f2 100644 --- a/tests/unit/setup/test_scenario_techniques_initializer.py +++ b/tests/unit/setup/test_scenario_techniques_initializer.py @@ -35,10 +35,10 @@ def reset_registries(): """Reset technique and target registries between tests.""" AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() yield AttackTechniqueRegistry.reset_instance() - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() @pytest.fixture @@ -48,7 +48,7 @@ def mock_adversarial_target(): # capabilities check inside get_default_adversarial_target requires multi_turn support target.capabilities.includes.return_value = True registry = TargetRegistry.get_registry_singleton() - registry.register_instance(target, name="adversarial_chat") + registry.instances.register(target, name="adversarial_chat") return target diff --git a/tests/unit/setup/test_scorer_initializer.py b/tests/unit/setup/test_scorer_initializer.py index 67884d284a..e9b0234ebf 100644 --- a/tests/unit/setup/test_scorer_initializer.py +++ b/tests/unit/setup/test_scorer_initializer.py @@ -54,14 +54,14 @@ class TestScorerInitializerInitialize: def setup_method(self) -> None: """Reset registries before each test.""" - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() self._clear_env_vars() def teardown_method(self) -> None: """Clean up after each test.""" - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() self._clear_env_vars() def _clear_env_vars(self) -> None: @@ -92,7 +92,7 @@ def _register_mock_target(self, *, name: str, underlying_model: str = "gpt-4o") }, ) registry = TargetRegistry.get_registry_singleton() - registry.register_instance(target, name=name) + registry.instances.register(target, name=name) return target def _register_all_scorer_targets(self) -> None: @@ -121,7 +121,7 @@ async def test_registers_all_scorer_variants(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - assert len(registry) == 28 + assert len(registry.instances) == 28 async def test_registers_gpt4o_scorers_when_only_gpt4o_targets(self) -> None: """Test that GPT4O-based scorers register when only GPT4O targets are available.""" @@ -133,9 +133,9 @@ async def test_registers_gpt4o_scorers_when_only_gpt4o_targets(self) -> None: registry = ScorerRegistry.get_registry_singleton() # Normal mode: falls back to gpt4o refusal - assert registry.get_instance_by_name("refusal_gpt4o_objective_strict") is not None + assert registry.instances.get("refusal_gpt4o_objective_strict") is not None # inverted_refusal uses the gpt4o refusal fallback - assert registry.get_instance_by_name("inverted_refusal") is not None + assert registry.instances.get("inverted_refusal") is not None async def test_refusal_scorers_registered(self) -> None: """Test that refusal scorers are registered when gpt4o target is available.""" @@ -145,7 +145,7 @@ async def test_refusal_scorers_registered(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - refusal_entries = registry.get_by_tag(tag=ScorerInitializerTags.REFUSAL) + refusal_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.REFUSAL) # 4 gpt4o prompt-template variants (gpt5_4, gpt5_1, unsafe skipped) assert len(refusal_entries) == 4 @@ -159,10 +159,10 @@ async def test_acs_scorers_registered_when_env_vars_set(self) -> None: registry = ScorerRegistry.get_registry_singleton() # 3 threshold + 4 harm = 7 ACS total - assert registry.get_instance_by_name("acs_threshold_05") is not None - assert registry.get_instance_by_name("acs_hate") is not None + assert registry.instances.get("acs_threshold_05") is not None + assert registry.instances.get("acs_hate") is not None - acs_entries = registry.get_by_tag(tag=ScorerInitializerTags.ACS) + acs_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.ACS) assert len(acs_entries) == 7 async def test_acs_scorers_skipped_without_env_vars(self) -> None: @@ -173,9 +173,9 @@ async def test_acs_scorers_skipped_without_env_vars(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - assert registry.get_instance_by_name("acs_threshold_01") is None - assert registry.get_instance_by_name("acs_threshold_05") is None - assert registry.get_instance_by_name("acs_hate") is None + assert registry.instances.get("acs_threshold_01") is None + assert registry.instances.get("acs_threshold_05") is None + assert registry.instances.get("acs_hate") is None async def test_likert_scorers_registered(self) -> None: """Test that likert scorers are registered for LikertScalePaths with evaluation files.""" @@ -188,7 +188,7 @@ async def test_likert_scorers_registered(self) -> None: for scale in LikertScalePaths: if scale.evaluation_files is not None: expected_name = f"likert_{scale.name.lower().removesuffix('_scale')}_gpt4o" - scorer = registry.get_instance_by_name(expected_name) + scorer = registry.instances.get(expected_name) assert scorer is not None, f"Likert scorer '{expected_name}' not found in registry" async def test_gracefully_skips_scorers_with_missing_target(self) -> None: @@ -200,11 +200,11 @@ async def test_gracefully_skips_scorers_with_missing_target(self) -> None: registry = ScorerRegistry.get_registry_singleton() # Refusal variants requiring missing targets should be skipped - assert registry.get_instance_by_name("refusal_gpt5_4") is None - assert registry.get_instance_by_name("refusal_gpt5_1") is None - assert registry.get_instance_by_name("refusal_gpt4o_unsafe") is None + assert registry.instances.get("refusal_gpt5_4") is None + assert registry.instances.get("refusal_gpt5_1") is None + assert registry.instances.get("refusal_gpt4o_unsafe") is None # But gpt4o-based ones should register - assert registry.get_instance_by_name("refusal_gpt4o_objective_lenient") is not None + assert registry.instances.get("refusal_gpt4o_objective_lenient") is not None async def test_default_tag_registers_all_current_scorers(self) -> None: """Test that default tag registers all current scorers.""" @@ -216,7 +216,7 @@ async def test_default_tag_registers_all_current_scorers(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - assert len(registry) == 28 + assert len(registry.instances) == 28 class TestScorerInitializerGetInfo: @@ -238,13 +238,13 @@ class TestScorerInitializerBestObjective: def setup_method(self) -> None: """Reset registries before each test.""" - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() def teardown_method(self) -> None: """Clean up after each test.""" - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() def _register_mock_target(self, *, name: str, underlying_model: str = "gpt-4o") -> OpenAIChatTarget: """Register a mock OpenAIChatTarget in the TargetRegistry.""" @@ -267,7 +267,7 @@ def _register_mock_target(self, *, name: str, underlying_model: str = "gpt-4o") }, ) registry = TargetRegistry.get_registry_singleton() - registry.register_instance(target, name=name) + registry.instances.register(target, name=name) return target @patch("pyrit.setup.initializers.components.scorers.find_objective_metrics_by_eval_hash") @@ -283,7 +283,7 @@ async def test_best_objective_tags_best_scorer(self, mock_find_metrics) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - results = registry.get_by_tag(tag=ScorerInitializerTags.BEST_OBJECTIVE) + results = registry.instances.get_by_tag(tag=ScorerInitializerTags.BEST_OBJECTIVE) assert len(results) >= 1 @patch("pyrit.setup.initializers.components.scorers.find_objective_metrics_by_eval_hash") @@ -297,9 +297,9 @@ async def test_best_objective_no_metrics_falls_back_to_category(self, mock_find_ registry = ScorerRegistry.get_registry_singleton() # Should fall back to tagging a composite scorer as best_objective - results = registry.get_by_tag(tag=ScorerInitializerTags.BEST_OBJECTIVE) + results = registry.instances.get_by_tag(tag=ScorerInitializerTags.BEST_OBJECTIVE) # Falls back to first composite if available (inverted_refusal) - composite_entries = registry.get_by_tag(tag=ScorerInitializerTags.OBJECTIVE_COMPOSITE) + composite_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.OBJECTIVE_COMPOSITE) if composite_entries: assert len(results) >= 1 else: @@ -325,7 +325,7 @@ def mock_metrics_by_hash(*, eval_hash: str, file_path=None) -> MagicMock | None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - results = registry.get_by_tag(tag=ScorerInitializerTags.BEST_OBJECTIVE) + results = registry.instances.get_by_tag(tag=ScorerInitializerTags.BEST_OBJECTIVE) assert len(results) == 1 assert ScorerInitializerTags.DEFAULT_OBJECTIVE_SCORER in results[0].tags @@ -342,11 +342,11 @@ async def test_best_objective_does_not_add_extra_entry(self, mock_find_metrics) await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - count_with_tag = len(registry) + count_with_tag = len(registry.instances) # Reset and run without metrics to get baseline count - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() self._register_mock_target(name=GPT4O_TARGET) mock_find_metrics.return_value = None @@ -354,7 +354,7 @@ async def test_best_objective_does_not_add_extra_entry(self, mock_find_metrics) await init2.initialize_async() registry2 = ScorerRegistry.get_registry_singleton() - count_without_tag = len(registry2) + count_without_tag = len(registry2.instances) assert count_with_tag == count_without_tag @@ -369,12 +369,12 @@ class TestScorerInitializerCategoryTags: } def setup_method(self) -> None: - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() def teardown_method(self) -> None: - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() for var in self.CONTENT_SAFETY_ENV_VARS: os.environ.pop(var, None) @@ -399,7 +399,7 @@ def _register_mock_target(self, *, name: str, underlying_model: str = "gpt-4o") }, ) registry = TargetRegistry.get_registry_singleton() - registry.register_instance(target, name=name) + registry.instances.register(target, name=name) return target async def test_scale_scorers_tagged_with_scale_category(self) -> None: @@ -411,7 +411,7 @@ async def test_scale_scorers_tagged_with_scale_category(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - scale_entries = registry.get_by_tag(tag=ScorerInitializerTags.SCALE) + scale_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.SCALE) assert len(scale_entries) >= 1 async def test_acs_threshold_scorers_tagged_separately(self) -> None: @@ -423,7 +423,7 @@ async def test_acs_threshold_scorers_tagged_separately(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - threshold_entries = registry.get_by_tag(tag=ScorerInitializerTags.ACS_THRESHOLD) + threshold_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.ACS_THRESHOLD) assert len(threshold_entries) == 3 for entry in threshold_entries: assert ScorerInitializerTags.ACS in entry.tags @@ -437,7 +437,7 @@ async def test_acs_harm_scorers_tagged_separately(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - harm_entries = registry.get_by_tag(tag=ScorerInitializerTags.ACS_HARM) + harm_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.ACS_HARM) assert len(harm_entries) == 4 for entry in harm_entries: assert ScorerInitializerTags.ACS in entry.tags @@ -450,7 +450,7 @@ async def test_likert_scorers_tagged_with_likert_category(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - likert_entries = registry.get_by_tag(tag=ScorerInitializerTags.LIKERT) + likert_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.LIKERT) expected_count = sum(1 for s in LikertScalePaths if s.evaluation_files is not None) assert len(likert_entries) == expected_count @@ -463,7 +463,7 @@ async def test_task_achieved_scorers_tagged(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - task_entries = registry.get_by_tag(tag=ScorerInitializerTags.TASK_ACHIEVED) + task_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.TASK_ACHIEVED) assert len(task_entries) == 2 async def test_composite_scorers_tagged(self) -> None: @@ -475,7 +475,7 @@ async def test_composite_scorers_tagged(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - composite_entries = registry.get_by_tag(tag=ScorerInitializerTags.OBJECTIVE_COMPOSITE) + composite_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.OBJECTIVE_COMPOSITE) assert len(composite_entries) >= 1 async def test_best_refusal_tags_preferred_scorer(self) -> None: @@ -487,7 +487,7 @@ async def test_best_refusal_tags_preferred_scorer(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - best = registry.get_by_tag(tag=ScorerInitializerTags.BEST_REFUSAL) + best = registry.instances.get_by_tag(tag=ScorerInitializerTags.BEST_REFUSAL) assert len(best) == 1 assert best[0].name == "refusal_gpt5_4" @@ -499,7 +499,7 @@ async def test_best_refusal_falls_back_when_preferred_missing(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - best = registry.get_by_tag(tag=ScorerInitializerTags.BEST_REFUSAL) + best = registry.instances.get_by_tag(tag=ScorerInitializerTags.BEST_REFUSAL) assert len(best) == 1 # Should be one of the gpt4o refusal variants assert ScorerInitializerTags.REFUSAL in best[0].tags @@ -513,7 +513,7 @@ async def test_best_acs_threshold_tagged(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - best = registry.get_by_tag(tag=ScorerInitializerTags.BEST_ACS_THRESHOLD) + best = registry.instances.get_by_tag(tag=ScorerInitializerTags.BEST_ACS_THRESHOLD) assert len(best) == 1 assert best[0].name == "acs_threshold_05" @@ -523,12 +523,12 @@ class TestScorerInitializerRoundRobin: """Tests for ScorerInitializer round-robin target preference via _get_chat_target_prefer_rr.""" def setup_method(self) -> None: - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() def teardown_method(self) -> None: - ScorerRegistry.reset_instance() - TargetRegistry.reset_instance() + ScorerRegistry.reset_registry_singleton() + TargetRegistry.reset_registry_singleton() def _register_mock_target(self, *, name: str, underlying_model: str = "gpt-4o") -> OpenAIChatTarget: """Register a mock OpenAIChatTarget in the TargetRegistry.""" @@ -552,7 +552,7 @@ def _register_mock_target(self, *, name: str, underlying_model: str = "gpt-4o") }, ) registry = TargetRegistry.get_registry_singleton() - registry.register_instance(target, name=name) + registry.instances.register(target, name=name) return target def _register_mock_rr_target(self, *, name: str) -> MagicMock: @@ -566,7 +566,7 @@ def _register_mock_rr_target(self, *, name: str) -> MagicMock: class_module="pyrit.prompt_target.round_robin_target", ) registry = TargetRegistry.get_registry_singleton() - registry.register_instance(rr_mock, name=name) + registry.instances.register(rr_mock, name=name) return rr_mock async def test_refusal_unsafe_uses_round_robin_when_available(self) -> None: @@ -579,7 +579,7 @@ async def test_refusal_unsafe_uses_round_robin_when_available(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - scorer = registry.get_instance_by_name("refusal_gpt4o_unsafe") + scorer = registry.instances.get("refusal_gpt4o_unsafe") assert scorer is not None assert scorer._prompt_target is rr_mock @@ -592,7 +592,7 @@ async def test_refusal_unsafe_falls_back_to_individual_when_no_rr(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - scorer = registry.get_instance_by_name("refusal_gpt4o_unsafe") + scorer = registry.instances.get("refusal_gpt4o_unsafe") assert scorer is not None assert scorer._prompt_target is individual_mock @@ -604,7 +604,7 @@ async def test_refusal_unsafe_skipped_when_target_not_available(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - assert registry.get_instance_by_name("refusal_gpt4o_unsafe") is None + assert registry.instances.get("refusal_gpt4o_unsafe") is None async def test_refusal_gpt4o_uses_round_robin_when_available(self) -> None: """Test that gpt4o-based refusal scorers use round-robin target wrapping gpt4o.""" @@ -615,7 +615,7 @@ async def test_refusal_gpt4o_uses_round_robin_when_available(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - scorer = registry.get_instance_by_name("refusal_gpt4o_objective_strict") + scorer = registry.instances.get("refusal_gpt4o_objective_strict") assert scorer is not None assert scorer._prompt_target is rr_mock @@ -627,7 +627,7 @@ async def test_refusal_gpt4o_falls_back_to_individual_when_no_rr(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - scorer = registry.get_instance_by_name("refusal_gpt4o_objective_strict") + scorer = registry.instances.get("refusal_gpt4o_objective_strict") assert scorer is not None assert scorer._prompt_target is individual_mock @@ -640,7 +640,7 @@ async def test_likert_scorers_use_round_robin_when_available(self) -> None: await init.initialize_async() registry = ScorerRegistry.get_registry_singleton() - likert_entries = registry.get_by_tag(tag=ScorerInitializerTags.LIKERT) + likert_entries = registry.instances.get_by_tag(tag=ScorerInitializerTags.LIKERT) assert len(likert_entries) > 0 for entry in likert_entries: assert entry.instance._prompt_target is rr_mock diff --git a/tests/unit/setup/test_targets_initializer.py b/tests/unit/setup/test_targets_initializer.py index 964c7789fb..f7dc4ddf78 100644 --- a/tests/unit/setup/test_targets_initializer.py +++ b/tests/unit/setup/test_targets_initializer.py @@ -32,13 +32,13 @@ class TestTargetInitializerInitialize: def setup_method(self) -> None: """Reset registry before each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() # Clear all target-related env vars self._clear_env_vars() def teardown_method(self) -> None: """Clean up after each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() self._clear_env_vars() def _clear_env_vars(self) -> None: @@ -55,7 +55,7 @@ async def test_initialize_runs_without_error_no_env_vars(self): # No targets should be registered registry = TargetRegistry.get_registry_singleton() - assert len(registry) == 0 + assert len(registry.instances) == 0 async def test_registers_target_when_env_vars_set(self): """Test that a target is registered when its env vars are set.""" @@ -67,8 +67,8 @@ async def test_registers_target_when_env_vars_set(self): await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "platform_openai_chat" in registry - target = registry.get_instance_by_name("platform_openai_chat") + assert "platform_openai_chat" in registry.instances + target = registry.instances.get("platform_openai_chat") assert target is not None assert target._model_name == "gpt-4o" @@ -82,7 +82,7 @@ async def test_does_not_register_target_without_endpoint(self): await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "platform_openai_chat" not in registry + assert "platform_openai_chat" not in registry.instances async def test_does_not_register_target_without_api_key(self): """Test that target is not registered if api_key env var is missing.""" @@ -94,7 +94,7 @@ async def test_does_not_register_target_without_api_key(self): await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "platform_openai_chat" not in registry + assert "platform_openai_chat" not in registry.instances async def test_registers_multiple_targets(self): """Test that multiple targets are registered when their env vars are set.""" @@ -112,9 +112,9 @@ async def test_registers_multiple_targets(self): await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert len(registry) == 2 - assert "platform_openai_chat" in registry - assert "openai_image_platform" in registry + assert len(registry.instances) == 2 + assert "platform_openai_chat" in registry.instances + assert "openai_image_platform" in registry.instances async def test_registers_azure_content_safety_without_model(self): """Test that PromptShieldTarget is registered without model_name (it doesn't use one).""" @@ -127,7 +127,7 @@ async def test_registers_azure_content_safety_without_model(self): await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "azure_content_safety" in registry + assert "azure_content_safety" in registry.instances async def test_underlying_model_passed_when_set(self): """Test that underlying_model is passed to target when env var is set.""" @@ -142,7 +142,7 @@ async def test_underlying_model_passed_when_set(self): await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - target = registry.get_instance_by_name("azure_openai_gpt4o") + target = registry.instances.get("azure_openai_gpt4o") assert target is not None assert target._model_name == "my-deployment-name" assert target._underlying_model == "gpt-4o" @@ -156,8 +156,8 @@ async def test_registers_ollama_without_api_key(self): await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "ollama" in registry - target = registry.get_instance_by_name("ollama") + assert "ollama" in registry.instances + target = registry.instances.get("ollama") assert target is not None assert target._model_name == "llama2" @@ -176,7 +176,7 @@ def mock_token_provider() -> str: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - target = registry.get_instance_by_name("azure_openai_gpt4o") + target = registry.instances.get("azure_openai_gpt4o") assert target is not None # The token provider gets wrapped by _ensure_async_token_provider, so just verify it's callable assert callable(target._api_key) # type: ignore[ty:unresolved-attribute] @@ -258,11 +258,11 @@ class TestTargetInitializerTags: def setup_method(self) -> None: """Reset registry before each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() def teardown_method(self) -> None: """Clean up after each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() async def test_no_tags_registers_default_only(self) -> None: """Test that no tags registers only default targets (not scorer variants).""" @@ -275,9 +275,9 @@ async def test_no_tags_registers_default_only(self) -> None: registry = TargetRegistry.get_registry_singleton() # Default targets should be registered (including temp9), scorer-only should not - assert registry.get_instance_by_name("azure_openai_gpt4o") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp9") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp0") is None + assert registry.instances.get("azure_openai_gpt4o") is not None + assert registry.instances.get("azure_openai_gpt4o_temp9") is not None + assert registry.instances.get("azure_openai_gpt4o_temp0") is None # Clean up del os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] @@ -295,9 +295,9 @@ async def test_default_tag_excludes_scorer_targets(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry.get_instance_by_name("azure_openai_gpt4o") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp9") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp0") is None + assert registry.instances.get("azure_openai_gpt4o") is not None + assert registry.instances.get("azure_openai_gpt4o_temp9") is not None + assert registry.instances.get("azure_openai_gpt4o_temp0") is None # Clean up del os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] @@ -315,9 +315,9 @@ async def test_scorer_tag_only_registers_scorer_targets(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry.get_instance_by_name("azure_openai_gpt4o") is None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp9") is None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp0") is not None + assert registry.instances.get("azure_openai_gpt4o") is None + assert registry.instances.get("azure_openai_gpt4o_temp9") is None + assert registry.instances.get("azure_openai_gpt4o_temp0") is not None # Clean up del os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] @@ -335,8 +335,8 @@ async def test_multiple_tags_registers_matching(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry.get_instance_by_name("azure_openai_gpt4o") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp9") is not None + assert registry.instances.get("azure_openai_gpt4o") is not None + assert registry.instances.get("azure_openai_gpt4o_temp9") is not None # Clean up del os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] @@ -354,8 +354,8 @@ async def test_all_tag_registers_all_targets(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry.get_instance_by_name("azure_openai_gpt4o") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o_temp9") is not None + assert registry.instances.get("azure_openai_gpt4o") is not None + assert registry.instances.get("azure_openai_gpt4o_temp9") is not None # Clean up del os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] @@ -369,11 +369,11 @@ class TestTargetInitializerDefaultObjectiveTarget: def setup_method(self) -> None: """Reset registry before each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() def teardown_method(self) -> None: """Clean up after each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() for var in ["OPENAI_CHAT_ENDPOINT", "OPENAI_CHAT_KEY", "OPENAI_CHAT_MODEL"]: os.environ.pop(var, None) @@ -389,9 +389,9 @@ async def test_openai_chat_registered_with_default_tag(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "openai_chat" in registry + assert "openai_chat" in registry.instances - entries = registry.get_by_tag(tag=TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET) + entries = registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET) assert len(entries) == 1 assert entries[0].name == "openai_chat" @@ -403,7 +403,7 @@ async def test_no_default_tag_when_env_vars_missing(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - entries = registry.get_by_tag(tag=TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET) + entries = registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET) assert len(entries) == 0 async def test_openai_chat_config_has_default_objective_target_flag(self) -> None: @@ -427,11 +427,11 @@ class TestTargetInitializerConfigTagPropagation: def setup_method(self) -> None: """Reset registry before each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() def teardown_method(self) -> None: """Clean up after each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() for var in [ "OBJECTIVE_SCORER_CHAT_ENDPOINT", "OBJECTIVE_SCORER_CHAT_KEY", @@ -457,14 +457,14 @@ async def test_register_target_propagates_config_tags(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "objective_scorer_chat" in registry + assert "objective_scorer_chat" in registry.instances - scorer_entries = registry.get_by_tag(tag=TargetInitializerTags.SCORER) + scorer_entries = registry.instances.get_by_tag(tag=TargetInitializerTags.SCORER) assert any(entry.name == "objective_scorer_chat" for entry in scorer_entries), ( "objective_scorer_chat should be discoverable by the SCORER tag after F1c" ) - default_entries = registry.get_by_tag(tag=TargetInitializerTags.DEFAULT) + default_entries = registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT) assert any(entry.name == "objective_scorer_chat" for entry in default_entries), ( "objective_scorer_chat declares both DEFAULT and SCORER tags; both must propagate" ) @@ -491,8 +491,8 @@ async def test_register_target_no_tags_in_config_no_extra_add_tags(self) -> None init = TargetInitializer() init._register_target(config) - mock_registry.register_instance.assert_called_once() - mock_registry.add_tags.assert_not_called() + mock_registry.instances.register.assert_called_once() + mock_registry.instances.add_tags.assert_not_called() finally: os.environ.pop("EMPTY_TAGS_ENDPOINT", None) @@ -511,11 +511,11 @@ async def test_register_target_default_objective_tag_still_applied(self) -> None await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - default_objective_entries = registry.get_by_tag(tag=TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET) + default_objective_entries = registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT_OBJECTIVE_TARGET) assert len(default_objective_entries) == 1 assert default_objective_entries[0].name == "openai_chat" - default_entries = registry.get_by_tag(tag=TargetInitializerTags.DEFAULT) + default_entries = registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT) assert any(entry.name == "openai_chat" for entry in default_entries), ( "openai_chat's config.tags=[DEFAULT] must propagate even when default_objective_target=True" ) @@ -534,12 +534,12 @@ class TestTargetInitializerAdversarialChatVariants: def setup_method(self) -> None: """Reset registry and clear variant env vars.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() self._clear_variant_env_vars() def teardown_method(self) -> None: """Reset registry and clear variant env vars.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() self._clear_variant_env_vars() @staticmethod @@ -565,9 +565,9 @@ async def test_variant_registers_with_default_tag(self, registry_name: str, env_ await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry_name in registry + assert registry_name in registry.instances - default_entries = registry.get_by_tag(tag=TargetInitializerTags.DEFAULT) + default_entries = registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT) assert any(entry.name == registry_name for entry in default_entries) @pytest.mark.parametrize(("registry_name", "env_prefix"), ADVERSARIAL_CHAT_VARIANTS) @@ -577,7 +577,7 @@ async def test_variant_skips_when_env_vars_missing(self, registry_name: str, env await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry_name not in registry + assert registry_name not in registry.instances @pytest.mark.parametrize(("registry_name", "env_prefix"), ADVERSARIAL_CHAT_VARIANTS) async def test_variant_skips_when_model_env_var_missing( @@ -595,7 +595,7 @@ async def test_variant_skips_when_model_env_var_missing( await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry_name not in registry + assert registry_name not in registry.instances captured_messages = [r.message for r in caplog.records] assert any(f"{env_prefix}_MODEL" in m for m in captured_messages), ( @@ -623,12 +623,12 @@ async def test_double_initialize_async_is_idempotent(self) -> None: init = TargetInitializer() await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - first_names = sorted(registry.get_names()) - first_default_count = len(registry.get_by_tag(tag=TargetInitializerTags.DEFAULT)) + first_names = sorted(registry.instances.get_names()) + first_default_count = len(registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT)) await init.initialize_async() - second_names = sorted(registry.get_names()) - second_default_count = len(registry.get_by_tag(tag=TargetInitializerTags.DEFAULT)) + second_names = sorted(registry.instances.get_names()) + second_default_count = len(registry.instances.get_by_tag(tag=TargetInitializerTags.DEFAULT)) assert first_names == second_names assert first_default_count == second_default_count @@ -668,12 +668,12 @@ class TestTargetInitializerAutoGroup: def setup_method(self) -> None: """Reset registry and clear env vars before each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() self._clear_env_vars() def teardown_method(self) -> None: """Clean up after each test.""" - TargetRegistry.reset_instance() + TargetRegistry.reset_registry_singleton() self._clear_env_vars() def _clear_env_vars(self) -> None: @@ -695,12 +695,14 @@ async def test_auto_groups_targets_with_same_underlying_model(self) -> None: # Find the auto-generated round-robin by checking for RoundRobinTarget instances rr_names = [ - name for name in registry.get_names() if isinstance(registry.get_instance_by_name(name), RoundRobinTarget) + name + for name in registry.instances.get_names() + if isinstance(registry.instances.get(name), RoundRobinTarget) ] assert len(rr_names) >= 1, "Expected at least one auto-grouped round-robin target" # The gpt-4o round-robin should contain both gpt4o targets - rr = registry.get_instance_by_name("OpenAIChatTarget_gpt-4o_rr") + rr = registry.instances.get("OpenAIChatTarget_gpt-4o_rr") assert rr is not None assert isinstance(rr, RoundRobinTarget) @@ -713,8 +715,8 @@ async def test_individual_targets_still_accessible_after_auto_group(self) -> Non await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry.get_instance_by_name("azure_openai_gpt4o") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o2") is not None + assert registry.instances.get("azure_openai_gpt4o") is not None + assert registry.instances.get("azure_openai_gpt4o2") is not None async def test_no_round_robin_when_single_target(self) -> None: """Test that no round-robin is created when only one target has a given model.""" @@ -724,9 +726,9 @@ async def test_no_round_robin_when_single_target(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert registry.get_instance_by_name("azure_openai_gpt4o") is not None + assert registry.instances.get("azure_openai_gpt4o") is not None # No round-robin should exist - rr = registry.get_instance_by_name("OpenAIChatTarget_gpt-4o_rr") + rr = registry.instances.get("OpenAIChatTarget_gpt-4o_rr") assert rr is None async def test_no_round_robin_when_no_targets(self) -> None: @@ -735,7 +737,7 @@ async def test_no_round_robin_when_no_targets(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert len(registry) == 0 + assert len(registry.instances) == 0 async def test_different_temperatures_not_grouped(self) -> None: """Test that targets with different temperatures are NOT grouped together.""" @@ -747,8 +749,8 @@ async def test_different_temperatures_not_grouped(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - base = registry.get_instance_by_name("azure_openai_gpt4o") - temp9 = registry.get_instance_by_name("azure_openai_gpt4o_temp9") + base = registry.instances.get("azure_openai_gpt4o") + temp9 = registry.instances.get("azure_openai_gpt4o_temp9") # Both should exist but have different behavioral keys assert base is not None @@ -767,8 +769,8 @@ async def test_different_target_classes_not_grouped(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - chat_target = registry.get_instance_by_name("azure_openai_gpt4o") - response_target = registry.get_instance_by_name("azure_openai_responses") + chat_target = registry.instances.get("azure_openai_gpt4o") + response_target = registry.instances.get("azure_openai_responses") assert chat_target is not None assert response_target is not None @@ -796,11 +798,13 @@ async def test_auto_group_disabled_when_false(self) -> None: registry = TargetRegistry.get_registry_singleton() # Individual targets should exist - assert registry.get_instance_by_name("azure_openai_gpt4o") is not None - assert registry.get_instance_by_name("azure_openai_gpt4o2") is not None + assert registry.instances.get("azure_openai_gpt4o") is not None + assert registry.instances.get("azure_openai_gpt4o2") is not None # But no round-robin should be created rr_targets = [ - name for name in registry.get_names() if isinstance(registry.get_instance_by_name(name), RoundRobinTarget) + name + for name in registry.instances.get_names() + if isinstance(registry.instances.get(name), RoundRobinTarget) ] assert len(rr_targets) == 0 @@ -816,7 +820,7 @@ async def test_auto_group_three_targets_same_model(self) -> None: await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - rr = registry.get_instance_by_name("OpenAIChatTarget_gpt-4o_rr") + rr = registry.instances.get("OpenAIChatTarget_gpt-4o_rr") assert rr is not None assert isinstance(rr, RoundRobinTarget) # Should have 3 inner targets (gpt4o, gpt4o2, unsafe_chat) @@ -849,16 +853,16 @@ async def test_auto_group_deduplicates_identical_targets(self) -> None: await init.initialize_async() # Register the duplicate after init so it's in the registry with a distinct name. - registry.register_instance(dup_target, name="unsafe1_duplicate") + registry.instances.register(dup_target, name="unsafe1_duplicate") init._registered_names.append("unsafe1_duplicate") # Re-run auto-grouping (clear existing RR first) - existing_rr = registry.get_instance_by_name("OpenAIChatTarget_gpt-4o_rr") + existing_rr = registry.instances.get("OpenAIChatTarget_gpt-4o_rr") if existing_rr: - registry._instances.pop("OpenAIChatTarget_gpt-4o_rr", None) + registry.instances._registry_items.pop("OpenAIChatTarget_gpt-4o_rr", None) init._auto_group_targets() - rr = registry.get_instance_by_name("OpenAIChatTarget_gpt-4o_rr") + rr = registry.instances.get("OpenAIChatTarget_gpt-4o_rr") assert rr is not None assert isinstance(rr, RoundRobinTarget) # Should have 3, not 4 — the duplicate should be deduplicated