@@ -21,12 +21,15 @@ class LangGraphRolloutProcessor(RolloutProcessor):
2121 def __init__ (
2222 self ,
2323 * ,
24- graph_factory : Callable [[Dict [str , Any ]], Any ],
24+ # Factory must accept RolloutProcessorConfig (parity with Pydantic AI processor)
25+ graph_factory : Callable [[RolloutProcessorConfig ], Any ],
2526 to_input : Optional [Callable [[EvaluationRow ], Dict [str , Any ]]] = None ,
2627 apply_result : Optional [Callable [[EvaluationRow , Any ], EvaluationRow ]] = None ,
2728 build_graph_kwargs : Optional [Callable [[CompletionParams ], Dict [str , Any ]]] = None ,
2829 input_key : str = "messages" ,
2930 output_key : str = "messages" ,
31+ # Optional: build per-invoke RunnableConfig dict from full RolloutProcessorConfig
32+ build_invoke_config : Optional [Callable [[RolloutProcessorConfig ], Dict [str , Any ]]] = None ,
3033 ) -> None :
3134 # Build the graph per-call using completion_params
3235 self ._graph_factory = graph_factory
@@ -35,6 +38,7 @@ def __init__(
3538 self ._build_graph_kwargs = build_graph_kwargs
3639 self ._input_key = input_key
3740 self ._output_key = output_key
41+ self ._build_invoke_config = build_invoke_config
3842
3943 def _default_to_input (self , row : EvaluationRow ) -> Dict [str , Any ]:
4044 messages = row .messages or []
@@ -121,14 +125,21 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
121125 if config .completion_params :
122126 graph_config = build_kwargs (config .completion_params )
123127
124- # (Re)build the graph for this call using the graph kwargs
125- graph_target = self ._graph_factory (graph_config or {})
128+ # (Re)build the graph for this call using the full typed config.
129+ graph_target = self ._graph_factory (config )
130+
131+ # Build per-invoke config if provided; otherwise reuse graph_config for backwards compat
132+ invoke_config : Optional [Dict [str , Any ]] = None
133+ if self ._build_invoke_config is not None :
134+ invoke_config = self ._build_invoke_config (config )
135+ elif graph_config is not None :
136+ invoke_config = graph_config
126137
127138 async def _process_row (row : EvaluationRow ) -> EvaluationRow :
128139 try :
129140 payload = to_input (row )
130- if graph_config is not None :
131- result = await graph_target .ainvoke (payload , config = graph_config )
141+ if invoke_config is not None :
142+ result = await graph_target .ainvoke (payload , config = invoke_config )
132143 else :
133144 result = await graph_target .ainvoke (payload )
134145 row = apply_result (row , result )
0 commit comments