Skip to content

Commit f170b08

Browse files
authored
fix(nodes): expose model identity in node event inputs (#80)
1 parent c9cb948 commit f170b08

7 files changed

Lines changed: 171 additions & 3 deletions

File tree

src/graphon/nodes/llm/llm_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,17 @@ def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity:
6767
return model_schema
6868

6969

70+
def build_model_identity_inputs(
71+
*,
72+
model_instance: PreparedLLMProtocol,
73+
) -> dict[str, Any]:
74+
"""Expose the prepared model identity in node inputs."""
75+
return {
76+
"model_provider": model_instance.provider,
77+
"model_name": model_instance.model_name,
78+
}
79+
80+
7081
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
7182
variable = variable_pool.get(selector)
7283
if variable is None:

src/graphon/nodes/llm/node.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,9 @@ def _prepare_run_prompt(
253253
collected_context=collected_context,
254254
)
255255
model_instance = self._prepare_model_instance()
256+
node_inputs.update(
257+
llm_utils.build_model_identity_inputs(model_instance=model_instance),
258+
)
256259
prompt_messages, stop = LLMNode.fetch_prompt_messages(
257260
sys_query=self._resolve_memory_query(),
258261
sys_files=files,
@@ -377,6 +380,9 @@ def _yield_run_completion(
377380
)
378381
break
379382

383+
node_inputs.update(
384+
llm_utils.build_model_identity_inputs(model_instance=self._model_instance),
385+
)
380386
process_data.update(
381387
self._build_process_data(
382388
prompt_messages=prompt_messages,

src/graphon/nodes/parameter_extractor/parameter_extractor_node.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def _prepare_run_context(self) -> _ParameterExtractorRunContext:
384384
"files": [f.to_dict() for f in files],
385385
"parameters": jsonable_encoder(node_data.parameters),
386386
"instruction": jsonable_encoder(node_data.instruction),
387+
**llm_utils.build_model_identity_inputs(model_instance=model_instance),
387388
}
388389
process_data = {
389390
"model_mode": node_data.model.mode,

src/graphon/nodes/question_classifier/question_classifier_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,18 @@ def _prepare_run_context(self) -> _QuestionClassifierRunContext:
275275
model_instance=model_instance,
276276
files=files,
277277
)
278+
inputs = {
279+
"query": query,
280+
**llm_utils.build_model_identity_inputs(model_instance=model_instance),
281+
}
278282
rendered_classes = [
279283
class_.model_copy(
280284
update={"name": variable_pool.convert_template(class_.name).text},
281285
)
282286
for class_ in node_data.classes
283287
]
284288
return _QuestionClassifierRunContext(
285-
inputs={"query": query},
289+
inputs=inputs,
286290
model_instance=model_instance,
287291
prompt_messages=prompt_messages,
288292
stop=stop,

tests/nodes/llm/test_node.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from graphon.model_runtime.entities.llm_entities import LLMUsage
6+
from graphon.node_events.node import ModelInvokeCompletedEvent, StreamCompletedEvent
7+
from graphon.nodes.llm import LLMNode, LLMNodeData
8+
from graphon.runtime.graph_runtime_state import GraphRuntimeState
9+
10+
from ...helpers import build_graph_init_params, build_variable_pool
11+
12+
13+
def _build_llm_node() -> LLMNode:
14+
return LLMNode(
15+
node_id="llm",
16+
config=LLMNodeData.model_validate({
17+
"title": "LLM",
18+
"model": {
19+
"provider": "openai",
20+
"name": "gpt-4o",
21+
"mode": "chat",
22+
"completion_params": {},
23+
},
24+
"prompt_template": [
25+
{
26+
"role": "user",
27+
"text": "Hello",
28+
}
29+
],
30+
"context": {"enabled": False},
31+
}),
32+
graph_init_params=build_graph_init_params(
33+
graph_config={"nodes": [], "edges": []}
34+
),
35+
graph_runtime_state=GraphRuntimeState(
36+
variable_pool=build_variable_pool(),
37+
start_at=0.0,
38+
),
39+
model_instance=MagicMock(
40+
provider="openai",
41+
model_name="gpt-4o",
42+
parameters={},
43+
stop=(),
44+
),
45+
llm_file_saver=MagicMock(),
46+
prompt_message_serializer=MagicMock(),
47+
)
48+
49+
50+
def test_run_emits_model_identity_in_node_result_inputs(
51+
monkeypatch: pytest.MonkeyPatch,
52+
) -> None:
53+
node = _build_llm_node()
54+
55+
monkeypatch.setattr(node, "_fetch_inputs", lambda **_: {})
56+
monkeypatch.setattr(node, "_fetch_jinja_inputs", lambda **_: {})
57+
monkeypatch.setattr(node, "_collect_run_context", lambda **_: iter(()))
58+
monkeypatch.setattr(
59+
LLMNode, "fetch_prompt_messages", staticmethod(lambda **_: ([], None))
60+
)
61+
monkeypatch.setattr(
62+
"graphon.nodes.llm.node.LLMNode.invoke_llm",
63+
lambda **_: iter([
64+
ModelInvokeCompletedEvent(
65+
text="Hello back",
66+
usage=LLMUsage.empty_usage(),
67+
finish_reason="stop",
68+
),
69+
]),
70+
)
71+
72+
events = list(node._run()) # noqa: SLF001
73+
completed_event = next(
74+
event for event in events if isinstance(event, StreamCompletedEvent)
75+
)
76+
77+
assert completed_event.node_run_result.inputs["model_provider"] == "openai"
78+
assert completed_event.node_run_result.inputs["model_name"] == "gpt-4o"

tests/nodes/parameter_extractor/test_prompts.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import time
2+
from types import SimpleNamespace
23
from typing import Any, cast
34
from unittest.mock import Mock
45

56
import pytest
67

7-
from graphon.model_runtime.entities.llm_entities import LLMMode
8+
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage
89
from graphon.model_runtime.entities.message_entities import PromptMessageRole
10+
from graphon.nodes.llm import llm_utils
911
from graphon.nodes.llm.entities import ModelConfig
1012
from graphon.nodes.parameter_extractor import parameter_extractor_node
1113
from graphon.nodes.parameter_extractor.entities import (
@@ -47,7 +49,12 @@ def _build_parameter_extractor_node() -> tuple[ParameterExtractorNode, VariableP
4749
start_at=time.perf_counter(),
4850
)
4951
init_params = build_graph_init_params(graph_config={"nodes": [], "edges": []})
50-
model_instance = Mock()
52+
model_instance = Mock(
53+
provider="test",
54+
model_name="test-model",
55+
parameters={},
56+
stop=(),
57+
)
5158
prompt_message_serializer = Mock()
5259
node = cast(
5360
ParameterExtractorNode,
@@ -151,6 +158,65 @@ def test_function_calling_prompt_template_renders_system_message() -> None:
151158
assert prompt_messages[1].text == "Extract the location from this request."
152159

153160

161+
def test_prepare_run_context_exposes_model_identity_in_inputs(
162+
monkeypatch: pytest.MonkeyPatch,
163+
) -> None:
164+
node, variable_pool = _build_parameter_extractor_node()
165+
variable_pool.add(("start", "query"), "weather in sf")
166+
167+
monkeypatch.setattr(
168+
llm_utils,
169+
"resolve_completion_params_variables",
170+
lambda parameters, _: parameters,
171+
)
172+
monkeypatch.setattr(
173+
node,
174+
"_fetch_llm_model_schema",
175+
lambda **_: SimpleNamespace(features=[]),
176+
)
177+
monkeypatch.setattr(node, "_build_run_prompt", lambda **_: ([], []))
178+
179+
run_context = node._prepare_run_context() # noqa: SLF001
180+
181+
assert run_context.inputs["query"] == "weather in sf"
182+
assert run_context.inputs["model_provider"] == "test"
183+
assert run_context.inputs["model_name"] == "test-model"
184+
185+
186+
def test_parameter_extractor_run_emits_model_identity_in_inputs(
187+
monkeypatch: pytest.MonkeyPatch,
188+
) -> None:
189+
node, variable_pool = _build_parameter_extractor_node()
190+
variable_pool.add(("start", "query"), "weather in sf")
191+
192+
monkeypatch.setattr(
193+
llm_utils,
194+
"resolve_completion_params_variables",
195+
lambda parameters, _: parameters,
196+
)
197+
monkeypatch.setattr(
198+
node,
199+
"_fetch_llm_model_schema",
200+
lambda **_: SimpleNamespace(features=[]),
201+
)
202+
monkeypatch.setattr(node, "_build_run_prompt", lambda **_: ([], []))
203+
204+
invoke_result = SimpleNamespace(
205+
usage=LLMUsage.empty_usage(),
206+
message=SimpleNamespace(
207+
get_text_content=lambda: "{}",
208+
tool_calls=[],
209+
),
210+
)
211+
monkeypatch.setattr(node.model_instance, "invoke_llm", lambda **_: invoke_result)
212+
213+
result = node._run() # noqa: SLF001
214+
215+
assert result.inputs["query"] == "weather in sf"
216+
assert result.inputs["model_provider"] == "test"
217+
assert result.inputs["model_name"] == "test-model"
218+
219+
154220
def test_parameter_extractor_accepts_dependency_bundle() -> None:
155221
variable_pool = build_variable_pool(variables=[])
156222
runtime_state = GraphRuntimeState(

tests/nodes/question_classifier/test_question_classifier_node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ def test_question_classifier_run_returns_custom_class_label(
300300
assert result.outputs["class_name"] == "Questions about refunds"
301301
assert result.outputs["class_label"] == "Refund desk"
302302
assert result.outputs["class_id"] == "refund"
303+
assert result.inputs["model_provider"] == "openai"
304+
assert result.inputs["model_name"] == "gpt-4o"
303305

304306

305307
def test_question_classifier_run_falls_back_to_canonical_class_label(

0 commit comments

Comments
 (0)