Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/message_normalizer/json_schema_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _adapt_message(self, *, message: Message) -> Message:
if not changed:
return message

return Message(new_pieces)
return Message(message_pieces=new_pieces)

def _adapt_piece(self, *, piece: MessagePiece) -> MessagePiece:
"""
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/message_normalizer/test_json_schema_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TestJsonSchemaNormalizer:
async def test_text_piece_gets_schema_appended_to_converted_value(self, normalizer: JsonSchemaNormalizer) -> None:
schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
piece = _text_piece(value="Answer the question.", metadata={JSON_SCHEMA_METADATA_KEY: schema})
message = Message([piece])
message = Message(message_pieces=[piece])

result = await normalizer.normalize_async([message])
out_piece = result[0].message_pieces[0]
Expand All @@ -66,7 +66,7 @@ async def test_text_piece_preserves_other_metadata(self, normalizer: JsonSchemaN
"other": 7,
},
)
result = await normalizer.normalize_async([Message([piece])])
result = await normalizer.normalize_async([Message(message_pieces=[piece])])
new_metadata = result[0].message_pieces[0].prompt_metadata
assert JSON_SCHEMA_METADATA_KEY not in new_metadata
assert new_metadata == {"response_format": "json", "other": 7}
Expand All @@ -76,7 +76,7 @@ async def test_non_text_piece_only_strips_key(self, normalizer: JsonSchemaNormal
piece = _image_piece(value="fake.jpg", metadata={JSON_SCHEMA_METADATA_KEY: schema, "extra": "stay"})
original_converted_value = piece.converted_value

result = await normalizer.normalize_async([Message([piece])])
result = await normalizer.normalize_async([Message(message_pieces=[piece])])
out_piece = result[0].message_pieces[0]

assert JSON_SCHEMA_METADATA_KEY not in out_piece.prompt_metadata
Expand All @@ -87,7 +87,7 @@ async def test_non_text_piece_only_strips_key(self, normalizer: JsonSchemaNormal

async def test_no_schema_is_noop(self, normalizer: JsonSchemaNormalizer) -> None:
piece = _text_piece(value="just say hi", metadata={"unrelated": True})
message = Message([piece])
message = Message(message_pieces=[piece])

result = await normalizer.normalize_async([message])

Expand All @@ -98,7 +98,7 @@ async def test_input_pieces_not_mutated(self, normalizer: JsonSchemaNormalizer)
schema = {"type": "object"}
piece = _text_piece(value="hi", metadata={JSON_SCHEMA_METADATA_KEY: schema})

await normalizer.normalize_async([Message([piece])])
await normalizer.normalize_async([Message(message_pieces=[piece])])

# The original piece still carries the schema and its unchanged text.
assert piece.prompt_metadata == {JSON_SCHEMA_METADATA_KEY: schema}
Expand All @@ -117,7 +117,7 @@ async def test_mixed_pieces_in_message_each_handled(self, normalizer: JsonSchema
)
no_schema_piece = _text_piece(value="z", metadata={"foo": "bar"}, conversation_id=conversation_id)

result = await normalizer.normalize_async([Message([text_piece, image_piece, no_schema_piece])])
result = await normalizer.normalize_async([Message(message_pieces=[text_piece, image_piece, no_schema_piece])])
out_pieces = result[0].message_pieces

assert JSON_SCHEMA_METADATA_KEY not in out_pieces[0].prompt_metadata
Expand All @@ -133,8 +133,8 @@ async def test_mixed_pieces_in_message_each_handled(self, normalizer: JsonSchema

async def test_multiple_messages(self, normalizer: JsonSchemaNormalizer) -> None:
schema = {"type": "object"}
msg_with_schema = Message([_text_piece(value="a", metadata={JSON_SCHEMA_METADATA_KEY: schema})])
msg_without_schema = Message([_text_piece(value="b", metadata={})])
msg_with_schema = Message(message_pieces=[_text_piece(value="a", metadata={JSON_SCHEMA_METADATA_KEY: schema})])
msg_without_schema = Message(message_pieces=[_text_piece(value="b", metadata={})])

result = await normalizer.normalize_async([msg_with_schema, msg_without_schema])
assert "### Response format" in result[0].message_pieces[0].converted_value
Expand All @@ -155,7 +155,7 @@ async def test_appended_text_lists_schema_keys(self, normalizer: JsonSchemaNorma
}
piece = _text_piece(value="prompt", metadata={JSON_SCHEMA_METADATA_KEY: schema})

result = await normalizer.normalize_async([Message([piece])])
result = await normalizer.normalize_async([Message(message_pieces=[piece])])
appended = result[0].message_pieces[0].converted_value

# Sanity-check that the rendered text actually surfaces schema field names.
Expand All @@ -170,7 +170,7 @@ async def test_custom_template_is_used(self) -> None:
schema = {"type": "object"}
piece = _text_piece(value="hi", metadata={JSON_SCHEMA_METADATA_KEY: schema})

result = await normalizer.normalize_async([Message([piece])])
result = await normalizer.normalize_async([Message(message_pieces=[piece])])
out_value = result[0].message_pieces[0].converted_value

assert "<<SCHEMA START>>" in out_value
Expand Down