Commit aec0975

Weave: Group workflow traces under the parent evaluation call (#663)
When an Evaluation run is started the eval_call id is pushed to the call stack so subsequent traces can be grouped underneath it. This allows the user to debug the predictions, scores and traces for each run easily. - This PR is a workaround as the eval (score) exporting will be migrated to the observability exporter_manager in the future - This workaround only works for local eval. Remote eval is yet to be solved. - This PR also adds a change to wait on the trace export to finish before finishing the evaluation run <img width="2044" height="907" alt="image" src="https://github.com/user-attachments/assets/4f169a63-7152-4b17-a9ea-bdf6ffda9218" /> ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) Approvers: - Matthew Penn (https://github.com/mpenn) URL: #663
1 parent: 992e59e

12 files changed: 313 additions & 53 deletions
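The mechanism is small: Weave keeps a call stack in context, and pushing the evaluation's call onto it makes every trace opened inside the `with` block a child of that call. A minimal sketch of the idea, not the toolkit's API (assuming Weave is installed; `eval_call` stands in for the Call object that `EvaluationLogger` creates, which the diff below reads from `eval_logger._evaluate_call`, and `run_workflow` is an illustrative stand-in for the workflow under evaluation):

```python
# Minimal sketch of the grouping mechanism.
from weave.trace.context.call_context import set_call_stack

with set_call_stack([eval_call]):
    # Every weave op invoked in this block is recorded as a child of
    # eval_call, so workflow traces nest under the evaluation run.
    run_workflow()
```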


packages/nvidia_nat_opentelemetry/tests/observability/test_otel_span_adapter_exporter.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -241,7 +241,7 @@ async def test_end_to_end_span_processing(self, basic_exporter_config, sample_st
         exporter.export(sample_end_event)
 
         # Wait for async processing
-        await exporter._wait_for_tasks()
+        await exporter.wait_for_tasks()
 
         # Verify that export was called (span was processed and exported)
         mock_otlp_exporter.export.assert_called()
@@ -295,7 +295,7 @@ async def test_batching_behavior(self, mock_otlp_exporter_class, basic_exporter_
         exporter.export(end_event)
 
         # Wait for batch processing
-        await exporter._wait_for_tasks()
+        await exporter.wait_for_tasks()
 
         # Verify that export was called (batching should trigger export)
         mock_otlp_exporter.export.assert_called()
```

packages/nvidia_nat_opentelemetry/tests/observability/test_otel_span_adapter_integration.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -124,7 +124,7 @@ async def test_actual_span_export_to_mock_server(self, mock_otlp_server, sample_
         exporter.export(end_event)
 
         # Wait for async export to complete
-        await exporter._wait_for_tasks()
+        await exporter.wait_for_tasks()
 
         # Give a small buffer for HTTP request to complete
         await asyncio.sleep(0.1)
@@ -158,7 +158,7 @@ async def test_export_error_handling_with_real_endpoint(self, sample_events):
         exporter.export(end_event)
 
         # Wait for export attempt (should fail but not crash)
-        await exporter._wait_for_tasks()
+        await exporter.wait_for_tasks()
         await asyncio.sleep(0.1)
 
         # Test passes if no exception was raised - error should be logged internally
@@ -202,7 +202,7 @@ async def test_span_batching_with_real_export(self, mock_otlp_server):
         exporter.export(end_event)
 
         # Wait for batch processing
-        await exporter._wait_for_tasks()
+        await exporter.wait_for_tasks()
         await asyncio.sleep(0.1)
 
         # Validate that batch export occurred
@@ -219,7 +219,7 @@ async def test_basic_export_functionality(self, mock_otlp_server, sample_events):
         async with exporter.start():
             exporter.export(start_event)
             exporter.export(end_event)
-            await exporter._wait_for_tasks()
+            await exporter.wait_for_tasks()
             await asyncio.sleep(0.1)
 
         # Validate that spans were exported
```

src/nat/builder/workflow.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -83,6 +83,9 @@ def has_single_output(self) -> bool:
 
         return self._entry_fn.has_single_output
 
+    async def get_all_exporters(self) -> dict[str, BaseExporter]:
+        return await self._exporter_manager.get_all_exporters()
+
     @asynccontextmanager
     async def run(self, message: InputT):
         """
```

src/nat/eval/config.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -44,6 +44,8 @@ class EvaluationRunConfig(BaseModel):
     # number of passes at each concurrency, if 0 the dataset is adjusted to a multiple of the
     # concurrency. This is only used if adjust_dataset_size is true
     num_passes: int = 0
+    # timeout for waiting for trace export tasks to complete
+    export_timeout: float = 60.0
 
 
 class EvaluationRunOutput(BaseModel):
```
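A hypothetical override of the new field, assuming the remaining `EvaluationRunConfig` fields have usable defaults; slow tracing backends may need more than the default 60 seconds:

```python
from nat.eval.config import EvaluationRunConfig

# Hypothetical: allow two minutes for trace export tasks to finish
# before the evaluation run is declared complete.
config = EvaluationRunConfig(export_timeout=120.0)
```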

src/nat/eval/evaluate.py

Lines changed: 68 additions & 23 deletions

```diff
@@ -63,7 +63,16 @@ def __init__(self, config: EvaluationRunConfig):
 
         # Helpers
         self.intermediate_step_adapter: IntermediateStepAdapter = IntermediateStepAdapter()
-        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration()
+
+        # Create evaluation trace context
+        try:
+            from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
+            self.eval_trace_context = WeaveEvalTraceContext()
+        except Exception:
+            from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+            self.eval_trace_context = EvalTraceContext()
+
+        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration(self.eval_trace_context)
         # Metadata
         self.eval_input: EvalInput | None = None
         self.workflow_interrupted: bool = False
@@ -401,6 +410,33 @@ def _get_workflow_alias(self, workflow_type: str | None = None):
 
         return workflow_type
 
+    async def wait_for_all_export_tasks_local(self, session_manager: SessionManager, timeout: float) -> None:
+        """Wait for all trace export tasks to complete for local workflows.
+
+        This only works for local workflows where we have direct access to the
+        SessionManager and its underlying workflow with exporter manager.
+        """
+        try:
+            workflow = session_manager.workflow
+            all_exporters = await workflow.get_all_exporters()
+            if not all_exporters:
+                logger.debug("No exporters to wait for")
+                return
+
+            logger.info("Waiting for export tasks from %d local exporters (timeout: %ds)", len(all_exporters), timeout)
+
+            for name, exporter in all_exporters.items():
+                try:
+                    await exporter.wait_for_tasks(timeout=timeout)
+                    logger.info("Export tasks completed for exporter: %s", name)
+                except Exception as e:
+                    logger.warning("Error waiting for export tasks from %s: %s", name, e)
+
+            logger.info("All local export task waiting completed")
+
+        except Exception as e:
+            logger.warning("Failed to wait for local export tasks: %s", e)
+
     async def run_and_evaluate(self,
                                session_manager: SessionManager | None = None,
                                job_id: str | None = None) -> EvaluationRunOutput:
@@ -442,11 +478,13 @@
         dataset_config = self.eval_config.general.dataset  # Currently only one dataset is supported
         if not dataset_config:
             logger.info("No dataset found, nothing to evaluate")
-            return EvaluationRunOutput(
-                workflow_output_file=self.workflow_output_file,
-                evaluator_output_files=self.evaluator_output_files,
-                workflow_interrupted=self.workflow_interrupted,
-            )
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=EvalInput(eval_input_items=[]),
+                                       evaluation_results=[],
+                                       usage_stats=UsageStats(),
+                                       profiler_results=ProfilerResults())
 
         dataset_handler = DatasetHandler(dataset_config=dataset_config,
                                          reps=self.config.reps,
@@ -456,30 +494,37 @@
         self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
         if not self.eval_input.eval_input_items:
             logger.info("Dataset is empty. Nothing to evaluate.")
-            return EvaluationRunOutput(
-                workflow_output_file=self.workflow_output_file,
-                evaluator_output_files=self.evaluator_output_files,
-                workflow_interrupted=self.workflow_interrupted,
-            )
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=self.eval_input,
+                                       evaluation_results=self.evaluation_results,
+                                       usage_stats=self.usage_stats,
+                                       profiler_results=ProfilerResults())
 
         # Run workflow and evaluate
         async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
             # Initialize Weave integration
             self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
 
             # Run workflow
-            if self.config.endpoint:
-                await self.run_workflow_remote()
-            else:
-                if not self.config.skip_workflow:
-                    if session_manager is None:
-                        session_manager = SessionManager(eval_workflow.build(),
-                                                         max_concurrency=self.eval_config.general.max_concurrency)
-                    await self.run_workflow_local(session_manager)
-
-            # Evaluate
-            evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
-            await self.run_evaluators(evaluators)
+            with self.eval_trace_context.evaluation_context():
+                if self.config.endpoint:
+                    await self.run_workflow_remote()
+                else:
+                    if not self.config.skip_workflow:
+                        if session_manager is None:
+                            session_manager = SessionManager(eval_workflow.build(),
+                                                             max_concurrency=self.eval_config.general.max_concurrency)
+                        await self.run_workflow_local(session_manager)
+
+                # Evaluate
+                evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
+                await self.run_evaluators(evaluators)
+
+            # Wait for all trace export tasks to complete (local workflows only)
+            if session_manager and not self.config.endpoint:
+                await self.wait_for_all_export_tasks_local(session_manager, timeout=self.config.export_timeout)
 
             # Profile the workflow
             profiler_results = await self.profile_workflow()
```
src/nat/eval/utils/eval_trace_ctx.py (new file)

Lines changed: 89 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any

logger = logging.getLogger(__name__)

# Type alias for evaluation call objects that have an optional 'id' attribute
EvalCallType = Any  # Could be Weave Call object or other tracing framework objects


class EvalTraceContext:
    """
    Evaluation trace context manager for coordinating traces.

    This class provides a framework-agnostic way to:
    1. Track evaluation calls/contexts
    2. Ensure proper parent-child relationships in traces
    """

    def __init__(self):
        self.eval_call: EvalCallType | None = None  # Store the evaluation call/context for propagation

    def set_eval_call(self, eval_call: EvalCallType | None) -> None:
        """Set the evaluation call/context for propagation to traces."""
        self.eval_call = eval_call
        if eval_call:
            logger.debug("Set evaluation call context: %s", getattr(eval_call, 'id', str(eval_call)))

    def get_eval_call(self) -> EvalCallType | None:
        """Get the current evaluation call/context."""
        return self.eval_call

    @contextmanager
    def evaluation_context(self):
        """
        Context manager that can be overridden by framework-specific implementations.
        Default implementation is a no-op.
        """
        yield


class WeaveEvalTraceContext(EvalTraceContext):
    """
    Weave-specific implementation of evaluation trace context.
    """

    def __init__(self):
        super().__init__()
        self.available = False
        self.set_call_stack: Callable[[list[EvalCallType]], Any] | None = None

        try:
            from weave.trace.context.call_context import set_call_stack
            self.set_call_stack = set_call_stack
            self.available = True
        except ImportError:
            self.available = False
            logger.debug("Weave not available for trace context")

    @contextmanager
    def evaluation_context(self):
        """Set the evaluation call as active context for Weave traces."""
        if self.available and self.eval_call and self.set_call_stack:
            try:
                with self.set_call_stack([self.eval_call]):
                    logger.debug("Set Weave evaluation call context: %s",
                                 getattr(self.eval_call, 'id', str(self.eval_call)))
                    yield
            except Exception as e:
                logger.warning("Failed to set Weave evaluation call context: %s", e)
                yield
        else:
            yield
```
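A usage sketch of the new module, following the same import-with-fallback wiring that `evaluate.py` uses above (`eval_call` is assumed to come from Weave's `EvaluationLogger`):

```python
# Sketch: prefer the Weave-aware context, fall back to the no-op base
# class when weave is not importable.
try:
    from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
    trace_ctx = WeaveEvalTraceContext()
except Exception:
    from nat.eval.utils.eval_trace_ctx import EvalTraceContext
    trace_ctx = EvalTraceContext()

trace_ctx.set_eval_call(eval_call)

with trace_ctx.evaluation_context():
    # Traces emitted here are parented under eval_call when Weave is
    # available; otherwise the context manager is a no-op.
    ...
```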

src/nat/eval/utils/weave_eval.py

Lines changed: 13 additions & 6 deletions

```diff
@@ -15,6 +15,7 @@
 
 import asyncio
 import logging
+from typing import TYPE_CHECKING
 from typing import Any
 
 from nat.eval.evaluator.evaluator_model import EvalInput
@@ -24,6 +25,9 @@
 from nat.eval.usage_stats import UsageStatsItem
 from nat.profiler.data_models import ProfilerResults
 
+if TYPE_CHECKING:
+    from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+
 logger = logging.getLogger(__name__)
 
 
@@ -32,18 +36,19 @@ class WeaveEvaluationIntegration:  # pylint: disable=too-many-public-methods
     Class to handle all Weave integration functionality.
     """
 
-    def __init__(self):
+    def __init__(self, eval_trace_context: "EvalTraceContext"):
         self.available = False
         self.client = None
         self.eval_logger = None
         self.pred_loggers = {}
+        self.eval_trace_context = eval_trace_context
 
         try:
             from weave.flow.eval_imperative import EvaluationLogger
             from weave.flow.eval_imperative import ScoreLogger
             from weave.trace.context import weave_client_context
-            self.EvaluationLogger = EvaluationLogger
-            self.ScoreLogger = ScoreLogger
+            self.evaluation_logger_cls = EvaluationLogger  # pylint: disable=invalid-name
+            self.score_logger_cls = ScoreLogger  # pylint: disable=invalid-name
             self.weave_client_context = weave_client_context
             self.available = True
         except ImportError:
@@ -89,9 +94,12 @@ def initialize_logger(self, workflow_alias: str, eval_input: EvalInput, config:
             weave_dataset = self._get_weave_dataset(eval_input)
             config_dict = config.model_dump(mode="json")
             config_dict["name"] = workflow_alias
-            self.eval_logger = self.EvaluationLogger(model=config_dict, dataset=weave_dataset)
+            self.eval_logger = self.evaluation_logger_cls(model=config_dict, dataset=weave_dataset)
             self.pred_loggers = {}
 
+            # Capture the current evaluation call for context propagation
+            self.eval_trace_context.set_eval_call(self.eval_logger._evaluate_call)
+
             return True
         except Exception as e:
             self.eval_logger = None
@@ -137,7 +145,7 @@ async def alog_score(self, eval_output: EvalOutput, evaluator_name: str):
         await asyncio.gather(*coros)
 
     async def afinish_loggers(self):
-        """Finish all prediction loggers."""
+        """Finish all prediction loggers and wait for exports."""
         if not self.eval_logger:
             return
 
@@ -157,7 +165,6 @@ def _log_profiler_metrics(self, profiler_results: ProfilerResults, usage_stats:
         if profiler_results.workflow_runtime_metrics:
             profile_metrics["wf_runtime_p95"] = profiler_results.workflow_runtime_metrics.p95
 
-        # TODO:get the LLM tokens from the usage stats and log them
         profile_metrics["total_runtime"] = usage_stats.total_runtime
 
         return profile_metrics
```

src/nat/observability/exporter/base_exporter.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -357,7 +357,7 @@ async def _cancel_tasks(self):
             except Exception as e:
                 logger.warning("Error while canceling task %s: %s", task.get_name(), e)
 
-    async def _wait_for_tasks(self, timeout: float = 5.0):
+    async def wait_for_tasks(self, timeout: float = 5.0):
         """Wait for all tracked tasks to complete with a timeout.
 
         Note: This method is NOT called during normal stop() operation for performance.
```
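Since `wait_for_tasks` is now public, callers can flush explicitly instead of relying on `stop()`; a sketch in the style of the updated tests above (`exporter` and `event` are assumed to exist):

```python
# Sketch: export an event, then block until the exporter's tracked
# tasks finish (or the timeout elapses) before tearing down.
async with exporter.start():
    exporter.export(event)
    await exporter.wait_for_tasks(timeout=5.0)
```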
