Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ _Important:_ For OpenAI Codex agents (most likely you!), your environment does n
- Testing is managed by `nox` with multiple shards for different Python versions
- Each shard represents specific package configurations

### Server fixtures (do NOT hand-roll fake servers)

- **Never create a `_FakeServer`, stub, or mock `TraceServerInterface`** to test
server-side logic. Use the existing `client` fixture (gives `client.server` +
`client.project_id`) or the `trace_server` fixture, which run against a real
SQLite/ClickHouse backend. Build inputs with the real APIs
(`obj_create`, `table_create`, etc.). Mock only external services we don't own.

### Key Test Shards

Focus on these primary test shards:
Expand Down
89 changes: 89 additions & 0 deletions tests/trace_server/test_trace_server_evaluation_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tests.trace.util import client_is_sqlite
from tests.trace_server.completions_util import with_simple_mock_litellm_completion
from weave.trace.refs import ObjectRef
from weave.trace.serialization.custom_objs import UnsafeDeserializationError
from weave.trace.weave_client import WeaveClient, generate_id
from weave.trace_server.errors import InvalidRequest
from weave.trace_server.interface.query import Query
Expand All @@ -28,6 +29,7 @@
EvaluationStatusNotFound,
EvaluationStatusReq,
EvaluationStatusRunning,
FileCreateReq,
GenAISpanRef,
ObjCreateReq,
PredictionCreateReq,
Expand Down Expand Up @@ -406,6 +408,93 @@ def evaluate_model_wrapped(req: EvaluateModelReq):
}


# The guard raises inside the lazy row-decode threadpool, which logs the failure
# at ERROR before it propagates out of asyncio.run; that log is the expected path.
@pytest.mark.disable_logging_error_check
def test_evaluate_model_rejects_unsafe_dataset_row(client):
"""An Op node in a dataset row must be refused at decode time, not loaded and
executed (WB-34909).

The evaluation and dataset objects carry no custom types themselves; the Op
CustomWeaveType lives in a table row that is only fetched and deserialized
lazily during evaluation. The worker disables unsafe custom-object decode, so
materializing that row raises instead of importing the code.
"""
project_id = client.project_id
entity, project = from_project_id(project_id)
server = client.server

def _obj(object_id: str, val: dict, builtin: str | None = None) -> str:
obj = {"project_id": project_id, "object_id": object_id, "val": val}
if builtin is not None:
obj["builtin_object_class"] = builtin
res = server.obj_create(ObjCreateReq.model_validate({"obj": obj}))
return ObjectRef(
entity=entity, project=project, name=object_id, _digest=res.digest
).uri

# A real file so the lazy file fetch succeeds; the decode guard, not a missing
# file, is what must stop the Op row from being reconstructed.
file_res = server.file_create(
FileCreateReq(project_id=project_id, name="obj.py", content=b"print('hi')")
)
table_res = server.table_create(
TableCreateReq.model_validate(
{
"table": {
"project_id": project_id,
"rows": [
{"input": "ok"},
{
"input": {
"_type": "CustomWeaveType",
"weave_type": {"type": "Op"},
"files": {"obj.py": file_res.digest},
"load_op": None,
}
},
],
}
}
)
)
dataset_ref = _obj(
"dataset_with_custom_row",
{
"_type": "Dataset",
"_class_name": "Dataset",
"_bases": ["BaseModel", "Object", "Dataset"],
"rows": f"weave:///{project_id}/table/{table_res.digest}",
},
)
evaluation_ref = _obj(
"eval_with_custom_row",
{
"_type": "Evaluation",
"_class_name": "Evaluation",
"_bases": ["BaseModel", "Object", "Evaluation"],
"dataset": dataset_ref,
"scorers": None,
},
)
model_ref = _obj(
"valid_model",
{"llm_model_id": "gpt-4o-mini", "default_params": {}},
builtin="LLMStructuredCompletionModel",
)

with pytest.raises(UnsafeDeserializationError):
evaluate_model_worker.evaluate_model(
evaluate_model_worker.EvaluateModelArgs(
project_id=project_id,
evaluation_ref=evaluation_ref,
model_ref=model_ref,
wb_user_id=entity,
evaluation_call_id=generate_id(),
)
)


def test_eval_results_query_basic(client):
project_id = client.project_id
entity, project = from_project_id(project_id)
Expand Down
106 changes: 37 additions & 69 deletions tests/trace_server/workers/test_evaluate_model_worker.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,46 @@
import pytest

from weave.trace_server.validation import (
UnsafePayloadError,
assert_safe_payload,
from weave.trace.serialization.custom_objs import (
KNOWN_TYPES,
SAFE_CUSTOM_WEAVE_TYPES,
is_safe_to_decode,
)
from weave.trace_server.workers.evaluate_model_worker.evaluate_model_worker import (
_assert_object_ref,
)


def test_assert_safe_payload():
# Safe payloads pass
assert_safe_payload({"name": "test", "value": 42})
assert_safe_payload({"nested": {"list": [1, "two", {"three": 3}]}})
assert_safe_payload("just a string ref")
assert_safe_payload(None)
assert_safe_payload([1, 2, 3])
assert_safe_payload({"_type": "ObjectRecord", "name": "test"})

# ALL CustomWeaveType payloads rejected — Op, unknown, and safe subtypes alike
with pytest.raises(UnsafePayloadError):
assert_safe_payload(
{
"_type": "CustomWeaveType",
"weave_type": {"type": "Op"},
"files": {"obj.py": "abc123"},
}
)

with pytest.raises(UnsafePayloadError):
assert_safe_payload(
{
"_type": "CustomWeaveType",
"weave_type": {"type": "PIL.Image.Image"},
"files": {"image.png": "abc123"},
}
)
@pytest.mark.parametrize(
("type_id", "load_op", "allow_unsafe", "expected"),
[
# allow_unsafe (normal client) -> anything decodes.
("Op", None, True, True),
("TotallyMadeUp", None, True, True),
# Worker client (allow_unsafe=False): only data-only serializers, no load_op.
("Op", None, False, False),
("TotallyMadeUp", None, False, False),
("PIL.Image.Image", None, False, True),
("weave.type_wrappers.Content.content.Content", None, False, True),
# A load_op routes through the fallback code path even for known types.
("PIL.Image.Image", "weave:///e/p/op/x:1", False, False),
],
)
def test_is_safe_to_decode(type_id, load_op, allow_unsafe, expected):
assert is_safe_to_decode(type_id, load_op, allow_unsafe=allow_unsafe) is expected

with pytest.raises(UnsafePayloadError):
assert_safe_payload(
{
"_type": "CustomWeaveType",
"weave_type": {"type": "TotallyMadeUpType"},
"load_op": "weave:///entity/project/object/evil:abc123",
}
)

# Nested in a list
with pytest.raises(UnsafePayloadError):
assert_safe_payload(
{
"scorers": [
{
"_type": "CustomWeaveType",
"weave_type": {"type": "Op"},
}
],
}
)
def test_safe_custom_weave_types_in_sync():
# Every known custom type must be classified: safe data-only serializers go in
# SAFE_CUSTOM_WEAVE_TYPES, and "Op" is the lone code-loading type. A newly
# added KNOWN_TYPE fails here until it is consciously placed on one side.
assert SAFE_CUSTOM_WEAVE_TYPES | {"Op"} == set(KNOWN_TYPES)

# Deeply nested in dicts
with pytest.raises(UnsafePayloadError):
assert_safe_payload(
{
"a": {
"b": {
"c": {
"_type": "CustomWeaveType",
"weave_type": {"type": "Op"},
}
}
},
}
)

# Missing/malformed weave_type still rejected
with pytest.raises(UnsafePayloadError):
assert_safe_payload({"_type": "CustomWeaveType"})
with pytest.raises(UnsafePayloadError):
assert_safe_payload({"_type": "CustomWeaveType", "weave_type": "not a dict"})
def test_assert_object_ref_rejects_op_and_non_object_refs():
# Op ref would be loaded/executed by client.get; table/other refs aren't models.
with pytest.raises(TypeError):
_assert_object_ref("weave:///ent/proj/op/some_op:abc123", "evaluation_ref")
with pytest.raises(TypeError):
_assert_object_ref("weave:///ent/proj/table/abc123", "evaluation_ref")
# A plain object ref is accepted.
_assert_object_ref("weave:///ent/proj/object/MyEval:abc123", "evaluation_ref")
56 changes: 55 additions & 1 deletion weave/trace/serialization/custom_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from collections.abc import Mapping
from typing import Any, Literal, TypedDict

from weave.trace.context.weave_client_context import require_weave_client
from weave.trace.context.weave_client_context import (
get_weave_client,
require_weave_client,
)
from weave.trace.op import op
from weave.trace.op_protocol import Op
from weave.trace.refs import ObjectRef
Expand Down Expand Up @@ -65,6 +68,48 @@ class EncodedCustomObjDict(
"weave.type_wrappers.Content.content.Content",
}

# The one type whose serializer loads code (imports a user-uploaded `.py`).
OP_CUSTOM_WEAVE_TYPE = "Op"

# Custom types whose registered serializer only reconstructs data (images, audio,
# etc.). Everything else -- "Op" and any unknown type -- routes through a
# code-loading path in `_decode_custom_obj`, so a client that forbids unsafe
# decode (server-side workers) refuses them. Kept in sync with KNOWN_TYPES minus
# "Op" by `test_safe_custom_weave_types_in_sync`; a new KNOWN_TYPE fails that test
# until it is consciously classified here.
SAFE_CUSTOM_WEAVE_TYPES = frozenset(
{
"PIL.Image.Image",
"wave.Wave_read",
"weave.type_handlers.Audio.audio.Audio",
"datetime.datetime",
"rich.markdown.Markdown",
"moviepy.video.VideoClip.VideoClip",
"weave.type_wrappers.Content.content.Content",
}
)


class UnsafeDeserializationError(Exception):
"""Raised when a client that forbids unsafe decode is asked to reconstruct a
code-bearing custom object (`Op`, or any type that falls back to `load_op`).
"""


def is_safe_to_decode(
weave_type_id: str, load_op_uri: str | None, *, allow_unsafe: bool
) -> bool:
"""Whether reconstructing this custom type is permitted.

`allow_unsafe` is the active client's policy. When False (server-side workers)
only registered data-only serializers with no `load_op` may be reconstructed;
a non-null `load_op` runs via the fallback path even for known types, so it is
never safe.
"""
if allow_unsafe:
return True
return weave_type_id in SAFE_CUSTOM_WEAVE_TYPES and load_op_uri is None


def encode_custom_obj(obj: Any) -> EncodedCustomObjDict | None:
serializer = get_serializer_for_obj(obj)
Expand Down Expand Up @@ -193,6 +238,15 @@ def _decode_custom_obj(
load_instance_op_uri: str | None = None,
) -> Any:
type_ = weave_type["type"]

client = get_weave_client()
allow_unsafe = client is None or client.allow_unsafe_custom_obj_decode
if not is_safe_to_decode(type_, load_instance_op_uri, allow_unsafe=allow_unsafe):
raise UnsafeDeserializationError(
f"Refusing to reconstruct custom object of type `{type_}`: the active "
f"client does not allow deserializing code-bearing custom objects."
)

found_serializer = False

# First, try to load the object using a known serializer
Expand Down
4 changes: 4 additions & 0 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,10 @@ def __init__(
self.future_executor_fastlane = FutureExecutor(max_workers=parallelism_upload)
self.ensure_project_exists = ensure_project_exists

# Server-side workers (eval, scoring) flip this off so deserialization
# never reconstructs code-bearing custom objects. See custom_objs.py.
self.allow_unsafe_custom_obj_decode = True

if ensure_project_exists:
resp = self.server.ensure_project_exists(entity, project)
# Set Client project name with updated project name
Expand Down
41 changes: 0 additions & 41 deletions weave/trace_server/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,44 +199,3 @@ def validate_alias_name(name: str) -> None:
)
if name in _RESERVED_ALIAS_NAMES:
raise ValueError(f"alias name '{name}' is reserved")


# --- Server-side deserialization safety ---


class UnsafePayloadError(ValueError):
"""Raised when a payload contains types that are not safe for server-side deserialization."""


def assert_safe_payload(value: Any, path: str = "root") -> None:
"""Recursively walk a raw serialized payload and reject CustomWeaveType nodes.

CustomWeaveType payloads can trigger code execution during deserialization
(e.g. Op types call __import__ on user-uploaded Python files, and unknown
types fall back to a load_op path that does the same). None of these should
be deserialized in a server-side worker process.

https://coreweave.atlassian.net/browse/VULNMGMT-1007
"""
if isinstance(value, list):
for idx, item in enumerate(value):
assert_safe_payload(item, f"{path}[{idx}]")
return

if not isinstance(value, dict):
return

if value.get("_type") == "CustomWeaveType":
weave_type = value.get("weave_type")
custom_type = (
weave_type.get("type", "unknown")
if isinstance(weave_type, dict)
else "unknown"
)
raise UnsafePayloadError(
f"Server-side deserialization does not allow CustomWeaveType payloads "
f"({custom_type} at {path})"
)

for key, item in value.items():
assert_safe_payload(item, f"{path}.{key}")
Loading
Loading