Skip to content

Commit 43a38e5

Browse files
rlundeen2Copilot
andauthored
FEAT: Always recompute ComponentIdentifier hashes (#2050)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 8376a81 commit 43a38e5

9 files changed

Lines changed: 285 additions & 329 deletions

File tree

pyrit/memory/memory_models.py

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
Conversation,
4343
ConversationReference,
4444
ConversationType,
45+
EvaluationIdentifier,
4546
MessagePiece,
4647
PromptDataType,
4748
ScenarioIdentifier,
@@ -62,49 +63,28 @@
6263
# Default pyrit_version for database records created before version tracking was added
6364
LEGACY_PYRIT_VERSION = "<0.10.0"
6465

65-
# Maximum length for string values in ComponentIdentifier.model_dump() when storing to the database.
66-
# Longer values are truncated with a "..." suffix.
67-
MAX_IDENTIFIER_VALUE_LENGTH: int = 80
6866

69-
70-
def _dump_identifier(identifier: ComponentIdentifier | None) -> dict[str, Any] | None:
71-
"""
72-
Serialize a ``ComponentIdentifier`` to a dict for JSON storage, truncating long values.
73-
74-
Args:
75-
identifier (ComponentIdentifier | None): The identifier to serialize, or None.
76-
77-
Returns:
78-
dict[str, Any] | None: The serialized identifier, or None if ``identifier`` is falsy.
79-
"""
80-
if not identifier:
81-
return None
82-
return identifier.model_dump(context={"max_value_length": MAX_IDENTIFIER_VALUE_LENGTH})
83-
84-
85-
def _dump_identifiers(identifiers: list[ComponentIdentifier]) -> list[dict[str, Any]]:
86-
"""
87-
Serialize a list of ``ComponentIdentifier`` objects for JSON storage.
88-
89-
Args:
90-
identifiers (list[ComponentIdentifier]): The identifiers to serialize.
91-
92-
Returns:
93-
list[dict[str, Any]]: The serialized identifiers in order.
94-
"""
95-
return [
96-
identifier.model_dump(context={"max_value_length": MAX_IDENTIFIER_VALUE_LENGTH}) for identifier in identifiers
97-
]
98-
99-
100-
def _load_identifier(stored: dict[str, Any] | None, *, pyrit_version: str | None = None) -> ComponentIdentifier | None:
67+
def _load_identifier(
68+
stored: dict[str, Any] | None,
69+
*,
70+
pyrit_version: str | None = None,
71+
eval_identifier_cls: type[EvaluationIdentifier] | None = None,
72+
) -> ComponentIdentifier | None:
10173
"""
10274
Reconstruct a ``ComponentIdentifier`` from its stored dict representation.
10375
76+
The content hash is recomputed on validation (never trusted from storage).
77+
When ``eval_identifier_cls`` is provided, the ``eval_hash`` is likewise
78+
recomputed from the (full) stored params and re-stamped onto the identifier,
79+
so the stored ``eval_hash`` value is never trusted on reload.
80+
10481
Args:
10582
stored (dict[str, Any] | None): The stored identifier dict, or None.
10683
pyrit_version (str | None): If provided, injected as the identifier's ``pyrit_version``
10784
so the reconstructed object reflects the version that created the row.
85+
eval_identifier_cls (type[EvaluationIdentifier] | None): If provided, the
86+
``EvaluationIdentifier`` subclass used to recompute and re-stamp the
87+
identifier's ``eval_hash`` on reload.
10888
10989
Returns:
11090
ComponentIdentifier | None: The reconstructed identifier, or None if ``stored`` is falsy.
@@ -113,7 +93,10 @@ def _load_identifier(stored: dict[str, Any] | None, *, pyrit_version: str | None
11393
return None
11494
if pyrit_version is not None:
11595
stored = {**stored, "pyrit_version": pyrit_version}
116-
return ComponentIdentifier.model_validate(stored)
96+
identifier = ComponentIdentifier.model_validate(stored)
97+
if eval_identifier_cls is not None:
98+
identifier = identifier.with_eval_hash(eval_identifier_cls(identifier).eval_hash)
99+
return identifier
117100

118101

119102
def _load_identifiers(
@@ -310,7 +293,7 @@ def __init__(self, *, entry: MessagePiece) -> None:
310293
self.timestamp = entry.timestamp
311294
self.labels = entry.labels
312295
self.prompt_metadata = entry.prompt_metadata
313-
self.converter_identifiers = _dump_identifiers(entry.converter_identifiers)
296+
self.converter_identifiers = [identifier.model_dump() for identifier in entry.converter_identifiers]
314297

315298
self.original_value = entry.original_value
316299
self.original_value_data_type = entry.original_value_data_type
@@ -399,7 +382,7 @@ def __init__(self, *, conversation: Conversation) -> None:
399382
conversation (Conversation): The conversation metadata to persist.
400383
"""
401384
self.conversation_id = conversation.conversation_id
402-
self.target_identifier = _dump_identifier(conversation.target_identifier)
385+
self.target_identifier = conversation.target_identifier.model_dump() if conversation.target_identifier else None
403386
self.pyrit_version = pyrit.__version__
404387

405388
def get_conversation(self) -> Conversation:
@@ -484,12 +467,13 @@ def __init__(self, *, entry: Score) -> None:
484467
self.score_rationale = entry.score_rationale
485468
self.score_metadata = entry.score_metadata or {}
486469
normalized_scorer = entry.scorer_class_identifier
487-
# Ensure eval_hash is set before truncation so it survives the DB round-trip
488-
if normalized_scorer is not None and normalized_scorer.eval_hash is None:
470+
# Always recompute eval_hash before dumping so the stored JSON carries the
471+
# freshly computed value for DB-level filtering (never a value from storage).
472+
if normalized_scorer is not None:
489473
normalized_scorer = normalized_scorer.with_eval_hash(
490474
ScorerEvaluationIdentifier(normalized_scorer).eval_hash
491475
)
492-
self.scorer_class_identifier = _dump_identifier(normalized_scorer) or {}
476+
self.scorer_class_identifier = normalized_scorer.model_dump() if normalized_scorer else {}
493477
self.prompt_request_response_id = entry.message_piece_id if entry.message_piece_id else None
494478
self.timestamp = entry.timestamp
495479
# Store in both columns for backward compatibility
@@ -505,9 +489,14 @@ def get_score(self) -> Score:
505489
Returns:
506490
Score: The reconstructed score object with all its data.
507491
"""
508-
# Convert dict back to ComponentIdentifier with the stored pyrit_version
492+
# Convert dict back to ComponentIdentifier with the stored pyrit_version;
493+
# eval_hash is recomputed on reload via ScorerEvaluationIdentifier.
509494
stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION
510-
scorer_identifier = _load_identifier(self.scorer_class_identifier, pyrit_version=stored_version)
495+
scorer_identifier = _load_identifier(
496+
self.scorer_class_identifier,
497+
pyrit_version=stored_version,
498+
eval_identifier_cls=ScorerEvaluationIdentifier,
499+
)
511500
return Score(
512501
id=self.id,
513502
score_value=self.score_value,
@@ -933,12 +922,15 @@ def __init__(self, *, entry: AttackResult) -> None:
933922
self.id = uuid.UUID(entry.attack_result_id)
934923
self.conversation_id = entry.conversation_id
935924
self.objective = entry.objective
936-
# Ensure eval_hash is set before truncation so it survives the DB round-trip
937-
if entry.atomic_attack_identifier and entry.atomic_attack_identifier.eval_hash is None:
925+
# Always recompute eval_hash before dumping so the stored JSON carries the
926+
# freshly computed value for DB-level filtering (never a value from storage).
927+
if entry.atomic_attack_identifier:
938928
entry.atomic_attack_identifier = entry.atomic_attack_identifier.with_eval_hash(
939929
AtomicAttackEvaluationIdentifier(entry.atomic_attack_identifier).eval_hash
940930
)
941-
self.atomic_attack_identifier = _dump_identifier(entry.atomic_attack_identifier)
931+
self.atomic_attack_identifier = (
932+
entry.atomic_attack_identifier.model_dump() if entry.atomic_attack_identifier else None
933+
)
942934
self.objective_sha256 = to_sha256(entry.objective)
943935

944936
# Use helper method for UUID conversions
@@ -1055,7 +1047,11 @@ def get_attack_result(self) -> AttackResult:
10551047
)
10561048
)
10571049

1058-
atomic_id = _load_identifier(self.atomic_attack_identifier)
1050+
# eval_hash is recomputed on reload via AtomicAttackEvaluationIdentifier.
1051+
atomic_id = _load_identifier(
1052+
self.atomic_attack_identifier,
1053+
eval_identifier_cls=AtomicAttackEvaluationIdentifier,
1054+
)
10591055

10601056
# Deserialize retry events from JSON
10611057
retry_events = []
@@ -1172,13 +1168,18 @@ def __init__(self, *, entry: ScenarioResult) -> None:
11721168
self.pyrit_version = entry.scenario_identifier.pyrit_version
11731169
self.scenario_init_data = entry.scenario_identifier.init_data
11741170
# Convert ComponentIdentifier to dict for JSON storage
1175-
self.objective_target_identifier = _dump_identifier(entry.objective_target_identifier) # type: ignore[ty:invalid-assignment]
1176-
# Ensure eval_hash is set before truncation so it survives the DB round-trip.
1177-
if entry.objective_scorer_identifier and entry.objective_scorer_identifier.eval_hash is None:
1171+
self.objective_target_identifier = ( # type: ignore[ty:invalid-assignment]
1172+
entry.objective_target_identifier.model_dump() if entry.objective_target_identifier else None
1173+
)
1174+
# Always recompute eval_hash before dumping so the stored JSON carries the
1175+
# freshly computed value for DB-level filtering (never a value from storage).
1176+
if entry.objective_scorer_identifier:
11781177
entry.objective_scorer_identifier = entry.objective_scorer_identifier.with_eval_hash(
11791178
ScorerEvaluationIdentifier(entry.objective_scorer_identifier).eval_hash
11801179
)
1181-
self.objective_scorer_identifier = _dump_identifier(entry.objective_scorer_identifier)
1180+
self.objective_scorer_identifier = (
1181+
entry.objective_scorer_identifier.model_dump() if entry.objective_scorer_identifier else None
1182+
)
11821183
self.scenario_run_state = entry.scenario_run_state.value
11831184
self.labels = entry.labels
11841185
self.number_tries = entry.number_tries
@@ -1224,8 +1225,13 @@ def get_scenario_result(self) -> ScenarioResult:
12241225
# Return empty attack_results - will be populated by memory_interface
12251226
attack_results: dict[str, list[AttackResult]] = {}
12261227

1227-
# Convert dict back to ComponentIdentifier with the stored pyrit_version
1228-
scorer_identifier = _load_identifier(self.objective_scorer_identifier, pyrit_version=stored_version)
1228+
# Convert dict back to ComponentIdentifier with the stored pyrit_version;
1229+
# eval_hash is recomputed on reload via ScorerEvaluationIdentifier.
1230+
scorer_identifier = _load_identifier(
1231+
self.objective_scorer_identifier,
1232+
pyrit_version=stored_version,
1233+
eval_identifier_cls=ScorerEvaluationIdentifier,
1234+
)
12291235

12301236
# Convert dict back to ComponentIdentifier for reconstruction
12311237
target_identifier = _load_identifier(self.objective_target_identifier)

0 commit comments

Comments
 (0)