diff --git a/garak/generators/replicate.py b/garak/generators/replicate.py index 72c233007..c9ffdb737 100644 --- a/garak/generators/replicate.py +++ b/garak/generators/replicate.py @@ -119,7 +119,7 @@ def _call_model( raise IOError( "Replicate endpoint didn't generate an Iterable[str]-type response. Make sure the endpoint is active." ) from exc - return [Message(r) for r in response] + return [Message(response)] DEFAULT_CLASS = "ReplicateGenerator" diff --git a/tests/generators/test_replicate.py b/tests/generators/test_replicate.py new file mode 100644 index 000000000..a7159d989 --- /dev/null +++ b/tests/generators/test_replicate.py @@ -0,0 +1,27 @@ +from unittest.mock import Mock + +from garak.attempt import Conversation, Message, Turn +from garak.generators.replicate import InferenceEndpoint + + +def test_replicate_inference_endpoint_returns_single_message(): + generator = InferenceEndpoint.__new__(InferenceEndpoint) + generator.client = Mock() + generator.name = "owner/private-endpoint" + generator.max_tokens = 20 + generator.temperature = 0.7 + generator.top_p = 0.9 + generator.repetition_penalty = 1.1 + + prediction = Mock() + prediction.output = ["hello", " world"] + deployment = Mock() + deployment.predictions.create.return_value = prediction + generator.client.deployments.get.return_value = deployment + + conv = Conversation([Turn("user", Message("test prompt"))]) + + output = generator._call_model(conv) + + assert output == [Message("hello world")] + prediction.wait.assert_called_once_with()