diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index ed16934a1f61..cd441785abf0 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -46,6 +46,7 @@ set_weave_client_global, ) from weave.trace.refs import TableRef +from weave.trace.settings import override_settings from weave.trace.vals import MissingSelfInstanceError from weave.trace.weave_client import sanitize_object_name from weave.trace_server import trace_server_interface as tsi @@ -4187,6 +4188,145 @@ def parent_op(x: int) -> int: assert child_traces == num_runs # Child was traced whenever parent was +def test_tracing_sample_rate_off_by_default_and_drops_roots(client): + random.seed(0) + executed = 0 + + @weave.op + def my_op(x: int) -> int: + nonlocal executed + executed += 1 + return x + 1 + + weave.publish(my_op) + + # Off by default: with no centralized rate set, every root is kept. + for i in range(5): + my_op(i) + assert len(list(my_op.calls())) == 5 + + # Centralized rate 0.0: the root is sampled out, but the wrapped function + # still runs — sampling only skips tracing, never the user's code. + with override_settings(tracing_sample_rate=0.0): + for i in range(5): + my_op(i) + + assert executed == 10 # function ran all ten times + assert len(list(my_op.calls())) == 5 # only the first five were traced + + +def test_tracing_sample_rate_env_var(client, monkeypatch): + @weave.op + def my_op(x: int) -> int: + return x + 1 + + weave.publish(my_op) + + monkeypatch.setenv("WEAVE_TRACING_SAMPLE_RATE", "0.0") + for i in range(5): + my_op(i) + + assert len(list(my_op.calls())) == 0 + + +def test_tracing_sample_rate_composition_is_multiplicative(client, monkeypatch): + # A per-op rate of 0.5 composed with a centralized rate of 0.5 yields an + # effective keep-rate of 0.25. 
A fixed random draw of 0.3 falls between the + # two thresholds: it is kept by 0.5 alone but dropped by the composed 0.25, + # which is exactly what proves the composition is multiplicative. + monkeypatch.setattr(random, "random", lambda: 0.3) + + @weave.op(tracing_sample_rate=0.5) + def half_op(x: int) -> int: + return x + 1 + + weave.publish(half_op) + + with override_settings(tracing_sample_rate=0.5): + half_op(1) # effective 0.25; 0.3 > 0.25 -> dropped + assert len(list(half_op.calls())) == 0 + + with override_settings(tracing_sample_rate=1.0): + half_op(2) # effective 0.5; 0.3 <= 0.5 -> kept + assert len(list(half_op.calls())) == 1 + + +class SamplingCarveoutModel(weave.Model): + @weave.op + def predict(self, question: str) -> dict: + return {"generated_text": question} + + +@weave.op +def sampling_carveout_score(expected: str, output: dict) -> dict: + return {"match": expected == output["generated_text"]} + + +@pytest.mark.asyncio +async def test_tracing_sample_rate_eval_carveout_declarative(client): + random.seed(0) + + @weave.op + def plain_op(x: int) -> int: + return x + 1 + + examples = [ + {"question": "a", "expected": "a"}, + {"question": "b", "expected": "x"}, + ] + evaluation = weave.Evaluation(dataset=examples, scorers=[sampling_carveout_score]) + + # A centralized rate of 0.0 would drop every root, but evaluations are exempt. + with override_settings(tracing_sample_rate=0.0): + for i in range(5): + plain_op(i) # control: a non-eval root, expected to be dropped + await evaluation.evaluate(SamplingCarveoutModel()) + + client.flush() + op_names = [ + c.op_name + for c in client.server.calls_query( + tsi.CallsQueryReq(project_id=client.project_id) + ).calls + ] + + # The control op was dropped, proving sampling is active at 0.0 ... + assert not any("plain_op" in name for name in op_names) + # ... yet the whole evaluation tree (root + children) survived. 
+ assert any("Evaluation.evaluate" in name for name in op_names) + assert any("Evaluation.predict_and_score" in name for name in op_names) + assert any("SamplingCarveoutModel.predict" in name for name in op_names) + + +def test_tracing_sample_rate_eval_carveout_imperative(client): + random.seed(0) + + @weave.op + def plain_op(x: int) -> int: + return x + 1 + + with override_settings(tracing_sample_rate=0.0): + for i in range(5): + plain_op(i) # control: a non-eval root, expected to be dropped + + ev = weave.EvaluationLogger() + pred = ev.log_prediction(inputs={"q": "hello"}, output="world") + pred.log_score(scorer="accuracy", score=True) + pred.finish() + ev.log_summary({"accuracy_mean": 1.0}) + + client.flush() + op_names = [ + c.op_name + for c in client.server.calls_query( + tsi.CallsQueryReq(project_id=client.project_id) + ).calls + ] + + assert not any("plain_op" in name for name in op_names) + assert any("Evaluation.evaluate" in name for name in op_names) + + def test_calls_len(client): @weave.op def test(): diff --git a/tests/trace/test_trace_settings.py b/tests/trace/test_trace_settings.py index 2c387748f94c..2b63cddd064b 100644 --- a/tests/trace/test_trace_settings.py +++ b/tests/trace/test_trace_settings.py @@ -31,6 +31,7 @@ should_disable_weave, should_print_call_link, should_redact_pii, + tracing_sample_rate, ) from weave.trace.weave_client import get_parallelism_settings from weave.utils.retry import with_retry @@ -619,6 +620,33 @@ def test_parse_and_apply_settings_is_alias_for_replace_settings(self): assert should_disable_weave() is False +@pytest.mark.usefixtures("clean_settings_env") +class TestTracingSampleRate: + def test_default_is_one(self): + assert tracing_sample_rate() == 1.0 + + def test_reads_snapshot(self): + replace_settings(UserSettings(tracing_sample_rate=0.25)) + assert tracing_sample_rate() == 0.25 + + def test_clamps_above_one(self): + replace_settings(UserSettings(tracing_sample_rate=2.0)) + assert tracing_sample_rate() == 1.0 + + 
def test_clamps_below_zero(self): + replace_settings(UserSettings(tracing_sample_rate=-1.0)) + assert tracing_sample_rate() == 0.0 + + def test_env_coerces_to_float_and_wins(self, monkeypatch): + replace_settings(UserSettings(tracing_sample_rate=1.0)) + monkeypatch.setenv("WEAVE_TRACING_SAMPLE_RATE", "0.1") + assert tracing_sample_rate() == 0.1 + + def test_env_is_clamped(self, monkeypatch): + monkeypatch.setenv("WEAVE_TRACING_SAMPLE_RATE", "5") + assert tracing_sample_rate() == 1.0 + + class TestUserSettingsValue: def test_is_frozen(self): settings = UserSettings() diff --git a/weave/trace/op.py b/weave/trace/op.py index d4fbbc881592..dc16ea6c13eb 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -67,6 +67,7 @@ ProcessedInputs, ) from weave.trace.util import log_once +from weave.trace_server import constants if TYPE_CHECKING: from weave.trace.call import Call, CallsIter, NoOpCall @@ -417,11 +418,32 @@ def should_skip_tracing_for_op(op: Op) -> bool: return not op._tracing_enabled +def _is_sampling_exempt(op: Op) -> bool: + """Root evaluation calls are never sampled out. + + Preserving evaluations is the whole point of the carve-out: an evaluation + that silently vanished under sampling would be far more surprising than a + dropped ad-hoc trace. Both the declarative `Evaluation.evaluate` op and the + imperative `EvaluationLogger` op resolve to the same op name, so a single + name check covers both entry points. The check runs only for root calls + (see `_should_sample_traces`), so the whole eval subtree is kept. 
+ """ + return getattr(op, "name", None) == constants.EVALUATION_RUN_OP_NAME + + def _should_sample_traces(op: Op) -> bool: if call_context.get_current_call(): return False # Don't sample traces for child calls - if random.random() > op.tracing_sample_rate: + if _is_sampling_exempt(op): + return False # Never sample out evaluation roots + + # Compose the centralized rate with the per-op rate multiplicatively: both + # express "fraction to keep", so the result is at least as strict as either (e.g. + # global 0.5 and per-op 0.5 keep ~25%). Defaults are 1.0 * 1.0 = 1.0 (keep all). + effective_rate = settings.tracing_sample_rate() * op.tracing_sample_rate + + if random.random() > effective_rate: return True # Sample traces for this call return False diff --git a/weave/trace/settings.py b/weave/trace/settings.py index 8451c3acc121..14fc24396e5c 100644 --- a/weave/trace/settings.py +++ b/weave/trace/settings.py @@ -317,6 +317,23 @@ class UserSettings: Can be overridden with the environment variable `WEAVE_USE_OTEL_V2` """ + tracing_sample_rate: float = 1.0 + """ + Centralized fraction of root traces to keep, from 0.0 to 1.0. Defaults to + 1.0 (keep everything), so sampling is off unless this is set. + + The decision is made once on the root call and composed multiplicatively + with the per-op `tracing_sample_rate` decorator argument, so the effective + rate is their product (e.g. 0.5 here and 0.5 on the op keeps ~25% of traces). Child + calls inherit the root's decision, so a trace is always kept or dropped as a + whole. Evaluation traces are never sampled out regardless of this value. + + Unlike the per-op rate, this is meant to be set in one place (for example as + a deployment-wide environment variable) so individual engineers do not each + have to opt in per op. + Can be overridden with the environment variable `WEAVE_TRACING_SAMPLE_RATE` + """ + class _SettingsOverrides(TypedDict, total=False): """Typed kwargs accepted by :func:`override_settings`. 
@@ -360,6 +377,7 @@ class _SettingsOverrides(TypedDict, total=False): enable_wal: bool disable_wal_sender: bool use_otel_v2: bool + tracing_sample_rate: float # Resolve string annotations once at import; used for env-var coercion. @@ -634,3 +652,11 @@ def should_disable_wal_sender() -> bool: def should_use_otel_v2() -> bool: """Returns whether OTel-capable integrations should use their OTel variant.""" return _env_or_default("use_otel_v2", _current_settings.get().use_otel_v2) + + +def tracing_sample_rate() -> float: + """Returns the centralized fraction of root traces to keep, clamped to [0.0, 1.0].""" + rate = _env_or_default( + "tracing_sample_rate", _current_settings.get().tracing_sample_rate + ) + return max(0.0, min(1.0, rate))