Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
set_weave_client_global,
)
from weave.trace.refs import TableRef
from weave.trace.settings import override_settings
from weave.trace.vals import MissingSelfInstanceError
from weave.trace.weave_client import sanitize_object_name
from weave.trace_server import trace_server_interface as tsi
Expand Down Expand Up @@ -4187,6 +4188,145 @@ def parent_op(x: int) -> int:
assert child_traces == num_runs # Child was traced whenever parent was


def test_tracing_sample_rate_off_by_default_and_drops_roots(client):
random.seed(0)
executed = 0

@weave.op
def my_op(x: int) -> int:
nonlocal executed
executed += 1
return x + 1

weave.publish(my_op)

# Off by default: with no centralized rate set, every root is kept.
for i in range(5):
my_op(i)
assert len(list(my_op.calls())) == 5

# Centralized rate 0.0: the root is sampled out, but the wrapped function
# still runs — sampling only skips tracing, never the user's code.
with override_settings(tracing_sample_rate=0.0):
for i in range(5):
my_op(i)

assert executed == 10 # function ran all ten times
assert len(list(my_op.calls())) == 5 # only the first five were traced


def test_tracing_sample_rate_env_var(client, monkeypatch):
@weave.op
def my_op(x: int) -> int:
return x + 1

weave.publish(my_op)

monkeypatch.setenv("WEAVE_TRACING_SAMPLE_RATE", "0.0")
for i in range(5):
my_op(i)

assert len(list(my_op.calls())) == 0


def test_tracing_sample_rate_composition_is_multiplicative(client, monkeypatch):
# A per-op rate of 0.5 composed with a centralized rate of 0.5 yields an
# effective keep-rate of 0.25. A fixed random draw of 0.3 falls between the
# two thresholds: it is kept by 0.5 alone but dropped by the composed 0.25,
# which is exactly what proves the composition is multiplicative.
monkeypatch.setattr(random, "random", lambda: 0.3)

@weave.op(tracing_sample_rate=0.5)
def half_op(x: int) -> int:
return x + 1

weave.publish(half_op)

with override_settings(tracing_sample_rate=0.5):
half_op(1) # effective 0.25; 0.3 > 0.25 -> dropped
assert len(list(half_op.calls())) == 0

with override_settings(tracing_sample_rate=1.0):
half_op(2) # effective 0.5; 0.3 <= 0.5 -> kept
assert len(list(half_op.calls())) == 1


class SamplingCarveoutModel(weave.Model):
@weave.op
def predict(self, question: str) -> dict:
return {"generated_text": question}


@weave.op
def sampling_carveout_score(expected: str, output: dict) -> dict:
return {"match": expected == output["generated_text"]}


@pytest.mark.asyncio
async def test_tracing_sample_rate_eval_carveout_declarative(client):
random.seed(0)

@weave.op
def plain_op(x: int) -> int:
return x + 1

examples = [
{"question": "a", "expected": "a"},
{"question": "b", "expected": "x"},
]
evaluation = weave.Evaluation(dataset=examples, scorers=[sampling_carveout_score])

# A centralized rate of 0.0 would drop every root, but evaluations are exempt.
with override_settings(tracing_sample_rate=0.0):
for i in range(5):
plain_op(i) # control: a non-eval root, expected to be dropped
await evaluation.evaluate(SamplingCarveoutModel())

client.flush()
op_names = [
c.op_name
for c in client.server.calls_query(
tsi.CallsQueryReq(project_id=client.project_id)
).calls
]

# The control op was dropped, proving sampling is active at 0.0 ...
assert not any("plain_op" in name for name in op_names)
# ... yet the whole evaluation tree (root + children) survived.
assert any("Evaluation.evaluate" in name for name in op_names)
assert any("Evaluation.predict_and_score" in name for name in op_names)
assert any("SamplingCarveoutModel.predict" in name for name in op_names)


def test_tracing_sample_rate_eval_carveout_imperative(client):
random.seed(0)

@weave.op
def plain_op(x: int) -> int:
return x + 1

with override_settings(tracing_sample_rate=0.0):
for i in range(5):
plain_op(i) # control: a non-eval root, expected to be dropped

ev = weave.EvaluationLogger()
pred = ev.log_prediction(inputs={"q": "hello"}, output="world")
pred.log_score(scorer="accuracy", score=True)
pred.finish()
ev.log_summary({"accuracy_mean": 1.0})

client.flush()
op_names = [
c.op_name
for c in client.server.calls_query(
tsi.CallsQueryReq(project_id=client.project_id)
).calls
]

assert not any("plain_op" in name for name in op_names)
assert any("Evaluation.evaluate" in name for name in op_names)


def test_calls_len(client):
@weave.op
def test():
Expand Down
28 changes: 28 additions & 0 deletions tests/trace/test_trace_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
should_disable_weave,
should_print_call_link,
should_redact_pii,
tracing_sample_rate,
)
from weave.trace.weave_client import get_parallelism_settings
from weave.utils.retry import with_retry
Expand Down Expand Up @@ -619,6 +620,33 @@ def test_parse_and_apply_settings_is_alias_for_replace_settings(self):
assert should_disable_weave() is False


@pytest.mark.usefixtures("clean_settings_env")
class TestTracingSampleRate:
def test_default_is_one(self):
assert tracing_sample_rate() == 1.0

def test_reads_snapshot(self):
replace_settings(UserSettings(tracing_sample_rate=0.25))
assert tracing_sample_rate() == 0.25

def test_clamps_above_one(self):
replace_settings(UserSettings(tracing_sample_rate=2.0))
assert tracing_sample_rate() == 1.0

def test_clamps_below_zero(self):
replace_settings(UserSettings(tracing_sample_rate=-1.0))
assert tracing_sample_rate() == 0.0

def test_env_coerces_to_float_and_wins(self, monkeypatch):
replace_settings(UserSettings(tracing_sample_rate=1.0))
monkeypatch.setenv("WEAVE_TRACING_SAMPLE_RATE", "0.1")
assert tracing_sample_rate() == 0.1

def test_env_is_clamped(self, monkeypatch):
monkeypatch.setenv("WEAVE_TRACING_SAMPLE_RATE", "5")
assert tracing_sample_rate() == 1.0


class TestUserSettingsValue:
def test_is_frozen(self):
settings = UserSettings()
Expand Down
24 changes: 23 additions & 1 deletion weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
ProcessedInputs,
)
from weave.trace.util import log_once
from weave.trace_server import constants

if TYPE_CHECKING:
from weave.trace.call import Call, CallsIter, NoOpCall
Expand Down Expand Up @@ -417,11 +418,32 @@ def should_skip_tracing_for_op(op: Op) -> bool:
return not op._tracing_enabled


def _is_sampling_exempt(op: Op) -> bool:
"""Root evaluation calls are never sampled out.

Preserving evaluations is the whole point of the carve-out: an evaluation
that silently vanished under sampling would be far more surprising than a
dropped ad-hoc trace. Both the declarative `Evaluation.evaluate` op and the
imperative `EvaluationLogger` op resolve to the same op name, so a single
name check covers both entry points. The check runs only for root calls
(see `_should_sample_traces`), so the whole eval subtree is kept.
"""
return getattr(op, "name", None) == constants.EVALUATION_RUN_OP_NAME


def _should_sample_traces(op: Op) -> bool:
if call_context.get_current_call():
return False # Don't sample traces for child calls

if random.random() > op.tracing_sample_rate:
if _is_sampling_exempt(op):
return False # Never sample out evaluation roots

# Compose the centralized rate with the per-op rate multiplicatively: both
# express "fraction to keep", so the stricter of the two wins (e.g. global
# 0.5 and per-op 0.5 keep ~25%). Defaults are 1.0 * 1.0 = 1.0 (keep all).
effective_rate = settings.tracing_sample_rate() * op.tracing_sample_rate

if random.random() > effective_rate:
return True # Sample traces for this call

return False
Expand Down
26 changes: 26 additions & 0 deletions weave/trace/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,23 @@ class UserSettings:
Can be overridden with the environment variable `WEAVE_USE_OTEL_V2`
"""

tracing_sample_rate: float = 1.0
"""
Centralized fraction of root traces to keep, from 0.0 to 1.0. Defaults to
1.0 (keep everything), so sampling is off unless this is set.

The decision is made once on the root call and composed multiplicatively
with the per-op `tracing_sample_rate` decorator argument, so the stricter of
the two wins (e.g. 0.5 here and 0.5 on the op keeps ~25% of traces). Child
calls inherit the root's decision, so a trace is always kept or dropped as a
whole. Evaluation traces are never sampled out regardless of this value.

Unlike the per-op rate, this is meant to be set in one place (for example as
a deployment-wide environment variable) so individual engineers do not each
have to opt in per op.
Can be overridden with the environment variable `WEAVE_TRACING_SAMPLE_RATE`
"""


class _SettingsOverrides(TypedDict, total=False):
"""Typed kwargs accepted by :func:`override_settings`.
Expand Down Expand Up @@ -360,6 +377,7 @@ class _SettingsOverrides(TypedDict, total=False):
enable_wal: bool
disable_wal_sender: bool
use_otel_v2: bool
tracing_sample_rate: float


# Resolve string annotations once at import; used for env-var coercion.
Expand Down Expand Up @@ -634,3 +652,11 @@ def should_disable_wal_sender() -> bool:
def should_use_otel_v2() -> bool:
"""Returns whether OTel-capable integrations should use their OTel variant."""
return _env_or_default("use_otel_v2", _current_settings.get().use_otel_v2)


def tracing_sample_rate() -> float:
"""Returns the centralized fraction of root traces to keep, clamped to [0.0, 1.0]."""
rate = _env_or_default(
"tracing_sample_rate", _current_settings.get().tracing_sample_rate
)
return max(0.0, min(1.0, rate))
Loading