Skip to content

Commit 8ff80f0

Browse files
committed
feat: implement OpenAIResponsesGenerator
Signed-off-by: ABeltramo <beltramo.ale@gmail.com>
1 parent 2e9b3f5 commit 8ff80f0

2 files changed

Lines changed: 563 additions & 0 deletions

File tree

garak/generators/openai.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,4 +431,171 @@ class OpenAIReasoningGenerator(OpenAIGenerator):
431431
}
432432

433433

434+
class OpenAIResponsesGenerator(Generator):
    """Generator using the OpenAI Responses API with server-side tool orchestration.

    Supports MCP servers and function tools via the ``tools`` parameter.
    Unlike the chat-completions generators, the Responses API runs the full
    agentic loop (tool calls → execution → follow-up) on the server side,
    returning only the final text to garak.

    Tool invocations reported in the response are kept as metadata: they are
    attached to the returned message under ``notes["tool_calls"]``.
    """

    ENV_VAR = "OPENAI_API_KEY"
    active = True
    generator_family_name = "OpenAIResponses"
    supports_multiple_generations = False

    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        # Optional base URL handed to the OpenAI client as ``base_url``.
        "uri": None,
        # Fixed ``instructions`` for every request. When set, it takes
        # precedence over system turns found in the prompt.
        "instructions": None,
        # Tool definitions forwarded verbatim as the request's ``tools``.
        "tools": [],
        "max_output_tokens": 1024,
        # Extra request arguments; merged last, so they can override the
        # arguments built by ``_call_model``.
        "extra_params": {},
        # response.output item types to collect into the final text.
        # "message" captures standard assistant text; "reasoning" captures
        # the model's reasoning summary (ResponseReasoningItem).
        "output_types": ["message"],
    }

    _unsafe_attributes = ["client"]

    def _load_unsafe(self):
        """(Re)create ``self.client`` from ``api_key`` and, when set, ``uri``."""
        client_kwargs = {"api_key": getattr(self, "api_key", None)}
        base_url = getattr(self, "uri", None)
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = openai.OpenAI(**client_kwargs)

    def __init__(self, name="", config_root=_config):
        """Load configuration, build the API client, then run base-class init.

        Args:
            name: Model name sent as ``model`` in every request.
            config_root: Configuration root passed to ``_load_config`` and to
                the base-class constructor.
        """
        self.name = name
        # Configuration is loaded before the client is built; presumably this
        # is what populates ``api_key`` / ``uri`` -- confirm in the base class.
        self._load_config(config_root)
        self.fullname = f"{self.generator_family_name} {self.name}"
        self.key_env_var = self.ENV_VAR
        self._load_unsafe()
        super().__init__(self.name, config_root=config_root)

    @staticmethod
    def _build_input(prompt: Union[Conversation, str]):
        """Convert a Conversation (or raw string) to the Responses API input format.

        System turns are excluded — callers should promote them to ``instructions``
        via :meth:`_extract_system_prompt` before calling this method.

        Returns:
            The prompt itself when it is a string; the turn's text when the
            conversation holds exactly one non-system turn and that turn has
            the ``user`` role; otherwise a list of ``message`` input items,
            one per non-system turn.
        """
        if isinstance(prompt, str):
            return prompt

        non_system = [turn for turn in prompt.turns if turn.role != "system"]

        # The API treats a bare string as user input, so only a lone *user*
        # turn may be collapsed; any other role must stay a structured item
        # or it would be replayed to the model as if the user had said it.
        if len(non_system) == 1 and non_system[0].role == "user":
            return non_system[0].content.text

        return [
            {
                "type": "message",
                "role": turn.role,
                "content": [
                    {
                        # Assistant turns are prior model output; everything
                        # else is input to the model.
                        "type": (
                            "output_text" if turn.role == "assistant" else "input_text"
                        ),
                        "text": turn.content.text,
                    }
                ],
            }
            for turn in non_system
        ]

    @staticmethod
    def _extract_system_prompt(prompt: Union[Conversation, str]) -> Union[str, None]:
        """Return system turn text from a Conversation as a single string, or None.

        Multiple system turns are joined with a newline so no content is lost.
        ``None`` is returned for non-Conversation prompts and for conversations
        without a system turn.
        """
        if not isinstance(prompt, Conversation):
            return None
        system_texts = [
            turn.content.text for turn in prompt.turns if turn.role == "system"
        ]
        if not system_texts:
            return None
        return "\n".join(system_texts)

    @staticmethod
    def _tool_call_entry(item, item_type: str) -> dict:
        """Summarise a tool-invocation output item as a plain dict.

        ``type`` and ``id`` are always present. Only attributes that are
        present (and not None) on the item are included beyond that, so the
        dict shape varies by tool type.
        """
        entry = {"type": item_type, "id": getattr(item, "id", None)}
        for attr in (
            "call_id",
            "name",
            "arguments",
            "input",
            "output",
            "error",
            "status",
            "server_label",
        ):
            value = getattr(item, attr, None)
            if value is not None:
                entry[attr] = value
        return entry

    @staticmethod
    def _item_texts(item, item_type) -> list:
        """Return the text fragments carried by one non-tool output item.

        ``message`` items contribute their ``output_text`` parts, ``reasoning``
        items their ``summary_text`` summaries; any other type falls back to a
        plain ``text`` attribute and logs a warning when there is none.
        """
        if item_type == "message":
            # ``or []`` guards against the attribute being present but None.
            return [
                part.text
                for part in (getattr(item, "content", None) or [])
                if getattr(part, "type", None) == "output_text"
            ]
        if item_type == "reasoning":
            return [
                summary.text
                for summary in (getattr(item, "summary", None) or [])
                if getattr(summary, "type", None) == "summary_text"
            ]
        fallback_text = getattr(item, "text", None)
        if fallback_text is None:
            logging.warning(
                "No text extraction defined for output item type %r", item_type
            )
            return []
        return [fallback_text]

    # Retry with Fibonacci backoff (wait value capped at 70) on the listed
    # transient errors and on an explicit GeneratorBackoffTrigger.
    @backoff.on_exception(
        backoff.fibo,
        (
            openai.RateLimitError,
            openai.InternalServerError,
            openai.APITimeoutError,
            openai.APIConnectionError,
            garak.exception.GeneratorBackoffTrigger,
        ),
        max_value=70,
    )
    def _call_model(
        self, prompt: Union[Conversation, str], generations_this_call: int = 1
    ) -> List[Union[Message, None]]:
        """Send one request to the Responses API and wrap the result.

        Args:
            prompt: Conversation or raw string to send.
            generations_this_call: Accepted for interface compatibility; not
                used, because one request yields a single response.

        Returns:
            A one-element list. The element is a ``Message`` holding the
            collected text (``None`` when only tool calls were returned) with
            tool calls under ``notes["tool_calls"]``, or ``None`` when the
            request was rejected as a bad request or produced neither text
            nor tool calls.

        Raises:
            garak.exception.GeneratorBackoffTrigger: When the response cannot
                be JSON-decoded; the decorator retries on it.
        """
        # The client is an unsafe attribute and may be absent or None.
        if getattr(self, "client", None) is None:
            self._load_unsafe()

        # Configured instructions win; system turns are only promoted when no
        # instructions are configured (``_build_input`` drops them either way).
        instructions = self.instructions or self._extract_system_prompt(prompt)
        create_args = {
            "model": self.name,
            "input": self._build_input(prompt),
            "max_output_tokens": self.max_output_tokens,
        }
        if instructions:
            create_args["instructions"] = instructions
        if self.tools:
            create_args["tools"] = self.tools
        # Merged last so user-supplied parameters can override the above.
        create_args.update(self.extra_params or {})

        try:
            response = self.client.responses.create(**create_args)
        except openai.BadRequestError as e:
            # Not retried: the same request would be rejected again.
            logging.exception(e)
            logging.error("Bad request: %s", repr(prompt))
            return [None]
        except json.decoder.JSONDecodeError as e:
            logging.exception(e)
            raise garak.exception.GeneratorBackoffTrigger from e

        text_parts = []
        tool_calls = []

        # Manually iterate response.output rather than using response.output_text so
        # that output_types can include non-message items such as reasoning summaries.
        for item in response.output or []:
            item_type = getattr(item, "type", None)
            # Any item whose type ends with "_call" is treated as a tool invocation and
            # captured in tool_calls. This covers all current
            # (function_call, mcp_call, web_search_call, file_search_call, computer_call)
            # and future tool types without needing per-type handling.
            if item_type is not None and item_type.endswith("_call"):
                tool_calls.append(self._tool_call_entry(item, item_type))
                continue
            if item_type in self.output_types:
                text_parts.extend(self._item_texts(item, item_type))

        # Responses API doesn't support multiple choices, we'll always return a list of one Message.
        # Return a Message whenever there is text or tool call metadata; None only when both are absent.
        if not text_parts and not tool_calls:
            return [None]
        text = "\n".join(text_parts) if text_parts else None
        notes = {"tool_calls": tool_calls} if tool_calls else {}
        return [Message(text, notes=notes)]
600+
434601
# Name of the generator class this module designates as its default.
# NOTE(review): presumably consumed by the plugin loader when no class is
# named explicitly -- confirm against the loader.
DEFAULT_CLASS = "OpenAIGenerator"

0 commit comments

Comments
 (0)