Skip to content

Commit 918c1dc

Browse files
gtarpenningclaude
andcommitted
fix(weave): refuse code-bearing custom objects on server-side decode
Server-side workers (the evaluate-model worker) reconstruct user-supplied objects. Gate custom-object deserialization on a per-client policy (WeaveClient.allow_unsafe_custom_obj_decode, default True) so workers can refuse to reconstruct Op / load_op-backed custom types at decode time, including dataset rows materialized lazily during evaluation. https://coreweave.atlassian.net/browse/WB-34909 Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent c6b4be6 commit 918c1dc

7 files changed

Lines changed: 209 additions & 131 deletions

File tree

AGENTS.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ _Important:_ For OpenAI Codex agents (most likely you!), your environment does n
8181
- Testing is managed by `nox` with multiple shards for different Python versions
8282
- Each shard represents specific package configurations
8383

84+
### Server fixtures (do NOT hand-roll fake servers)
85+
86+
- **Never create a `_FakeServer`, stub, or mock `TraceServerInterface`** to test
87+
server-side logic. Use the existing `client` fixture (gives `client.server` +
88+
`client.project_id`) or the `trace_server` fixture, which run against a real
89+
SQLite/ClickHouse backend. Build inputs with the real APIs
90+
(`obj_create`, `table_create`, etc.). Mock only external services we don't own.
91+
8492
### Key Test Shards
8593

8694
Focus on these primary test shards:

tests/trace_server/test_trace_server_evaluation_apis.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tests.trace.util import client_is_sqlite
1010
from tests.trace_server.completions_util import with_simple_mock_litellm_completion
1111
from weave.trace.refs import ObjectRef
12+
from weave.trace.serialization.custom_objs import UnsafeDeserializationError
1213
from weave.trace.weave_client import WeaveClient, generate_id
1314
from weave.trace_server.errors import InvalidRequest
1415
from weave.trace_server.interface.query import Query
@@ -28,6 +29,7 @@
2829
EvaluationStatusNotFound,
2930
EvaluationStatusReq,
3031
EvaluationStatusRunning,
32+
FileCreateReq,
3133
GenAISpanRef,
3234
ObjCreateReq,
3335
PredictionCreateReq,
@@ -406,6 +408,93 @@ def evaluate_model_wrapped(req: EvaluateModelReq):
406408
}
407409

408410

411+
# The guard raises inside the lazy row-decode threadpool, which logs the failure
412+
# at ERROR before it propagates out of asyncio.run; that log is the expected path.
413+
@pytest.mark.disable_logging_error_check
414+
def test_evaluate_model_rejects_unsafe_dataset_row(client):
415+
"""An Op node in a dataset row must be refused at decode time, not loaded and
416+
executed (WB-34909).
417+
418+
The evaluation and dataset objects carry no custom types themselves; the Op
419+
CustomWeaveType lives in a table row that is only fetched and deserialized
420+
lazily during evaluation. The worker disables unsafe custom-object decode, so
421+
materializing that row raises instead of importing the code.
422+
"""
423+
project_id = client.project_id
424+
entity, project = from_project_id(project_id)
425+
server = client.server
426+
427+
def _obj(object_id: str, val: dict, builtin: str | None = None) -> str:
428+
obj = {"project_id": project_id, "object_id": object_id, "val": val}
429+
if builtin is not None:
430+
obj["builtin_object_class"] = builtin
431+
res = server.obj_create(ObjCreateReq.model_validate({"obj": obj}))
432+
return ObjectRef(
433+
entity=entity, project=project, name=object_id, _digest=res.digest
434+
).uri
435+
436+
# A real file so the lazy file fetch succeeds; the decode guard, not a missing
437+
# file, is what must stop the Op row from being reconstructed.
438+
file_res = server.file_create(
439+
FileCreateReq(project_id=project_id, name="obj.py", content=b"print('hi')")
440+
)
441+
table_res = server.table_create(
442+
TableCreateReq.model_validate(
443+
{
444+
"table": {
445+
"project_id": project_id,
446+
"rows": [
447+
{"input": "ok"},
448+
{
449+
"input": {
450+
"_type": "CustomWeaveType",
451+
"weave_type": {"type": "Op"},
452+
"files": {"obj.py": file_res.digest},
453+
"load_op": None,
454+
}
455+
},
456+
],
457+
}
458+
}
459+
)
460+
)
461+
dataset_ref = _obj(
462+
"dataset_with_custom_row",
463+
{
464+
"_type": "Dataset",
465+
"_class_name": "Dataset",
466+
"_bases": ["BaseModel", "Object", "Dataset"],
467+
"rows": f"weave:///{project_id}/table/{table_res.digest}",
468+
},
469+
)
470+
evaluation_ref = _obj(
471+
"eval_with_custom_row",
472+
{
473+
"_type": "Evaluation",
474+
"_class_name": "Evaluation",
475+
"_bases": ["BaseModel", "Object", "Evaluation"],
476+
"dataset": dataset_ref,
477+
"scorers": None,
478+
},
479+
)
480+
model_ref = _obj(
481+
"valid_model",
482+
{"llm_model_id": "gpt-4o-mini", "default_params": {}},
483+
builtin="LLMStructuredCompletionModel",
484+
)
485+
486+
with pytest.raises(UnsafeDeserializationError):
487+
evaluate_model_worker.evaluate_model(
488+
evaluate_model_worker.EvaluateModelArgs(
489+
project_id=project_id,
490+
evaluation_ref=evaluation_ref,
491+
model_ref=model_ref,
492+
wb_user_id=entity,
493+
evaluation_call_id=generate_id(),
494+
)
495+
)
496+
497+
409498
def test_eval_results_query_basic(client):
410499
project_id = client.project_id
411500
entity, project = from_project_id(project_id)
Lines changed: 37 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,46 @@
11
import pytest
22

3-
from weave.trace_server.validation import (
4-
UnsafePayloadError,
5-
assert_safe_payload,
3+
from weave.trace.serialization.custom_objs import (
4+
KNOWN_TYPES,
5+
SAFE_CUSTOM_WEAVE_TYPES,
6+
is_safe_to_decode,
7+
)
8+
from weave.trace_server.workers.evaluate_model_worker.evaluate_model_worker import (
9+
_assert_object_ref,
610
)
711

812

9-
def test_assert_safe_payload():
10-
# Safe payloads pass
11-
assert_safe_payload({"name": "test", "value": 42})
12-
assert_safe_payload({"nested": {"list": [1, "two", {"three": 3}]}})
13-
assert_safe_payload("just a string ref")
14-
assert_safe_payload(None)
15-
assert_safe_payload([1, 2, 3])
16-
assert_safe_payload({"_type": "ObjectRecord", "name": "test"})
17-
18-
# ALL CustomWeaveType payloads rejected — Op, unknown, and safe subtypes alike
19-
with pytest.raises(UnsafePayloadError):
20-
assert_safe_payload(
21-
{
22-
"_type": "CustomWeaveType",
23-
"weave_type": {"type": "Op"},
24-
"files": {"obj.py": "abc123"},
25-
}
26-
)
27-
28-
with pytest.raises(UnsafePayloadError):
29-
assert_safe_payload(
30-
{
31-
"_type": "CustomWeaveType",
32-
"weave_type": {"type": "PIL.Image.Image"},
33-
"files": {"image.png": "abc123"},
34-
}
35-
)
13+
@pytest.mark.parametrize(
14+
("type_id", "load_op", "allow_unsafe", "expected"),
15+
[
16+
# allow_unsafe (normal client) -> anything decodes.
17+
("Op", None, True, True),
18+
("TotallyMadeUp", None, True, True),
19+
# Worker client (allow_unsafe=False): only data-only serializers, no load_op.
20+
("Op", None, False, False),
21+
("TotallyMadeUp", None, False, False),
22+
("PIL.Image.Image", None, False, True),
23+
("weave.type_wrappers.Content.content.Content", None, False, True),
24+
# A load_op routes through the fallback code path even for known types.
25+
("PIL.Image.Image", "weave:///e/p/op/x:1", False, False),
26+
],
27+
)
28+
def test_is_safe_to_decode(type_id, load_op, allow_unsafe, expected):
29+
assert is_safe_to_decode(type_id, load_op, allow_unsafe=allow_unsafe) is expected
3630

37-
with pytest.raises(UnsafePayloadError):
38-
assert_safe_payload(
39-
{
40-
"_type": "CustomWeaveType",
41-
"weave_type": {"type": "TotallyMadeUpType"},
42-
"load_op": "weave:///entity/project/object/evil:abc123",
43-
}
44-
)
4531

46-
# Nested in a list
47-
with pytest.raises(UnsafePayloadError):
48-
assert_safe_payload(
49-
{
50-
"scorers": [
51-
{
52-
"_type": "CustomWeaveType",
53-
"weave_type": {"type": "Op"},
54-
}
55-
],
56-
}
57-
)
32+
def test_safe_custom_weave_types_in_sync():
33+
# Every known custom type must be classified: safe data-only serializers go in
34+
# SAFE_CUSTOM_WEAVE_TYPES, and "Op" is the lone code-loading type. A newly
35+
# added KNOWN_TYPE fails here until it is consciously placed on one side.
36+
assert SAFE_CUSTOM_WEAVE_TYPES | {"Op"} == set(KNOWN_TYPES)
5837

59-
# Deeply nested in dicts
60-
with pytest.raises(UnsafePayloadError):
61-
assert_safe_payload(
62-
{
63-
"a": {
64-
"b": {
65-
"c": {
66-
"_type": "CustomWeaveType",
67-
"weave_type": {"type": "Op"},
68-
}
69-
}
70-
},
71-
}
72-
)
7338

74-
# Missing/malformed weave_type still rejected
75-
with pytest.raises(UnsafePayloadError):
76-
assert_safe_payload({"_type": "CustomWeaveType"})
77-
with pytest.raises(UnsafePayloadError):
78-
assert_safe_payload({"_type": "CustomWeaveType", "weave_type": "not a dict"})
39+
def test_assert_object_ref_rejects_op_and_non_object_refs():
40+
# Op ref would be loaded/executed by client.get; table/other refs aren't models.
41+
with pytest.raises(TypeError):
42+
_assert_object_ref("weave:///ent/proj/op/some_op:abc123", "evaluation_ref")
43+
with pytest.raises(TypeError):
44+
_assert_object_ref("weave:///ent/proj/table/abc123", "evaluation_ref")
45+
# A plain object ref is accepted.
46+
_assert_object_ref("weave:///ent/proj/object/MyEval:abc123", "evaluation_ref")

weave/trace/serialization/custom_objs.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from collections.abc import Mapping
55
from typing import Any, Literal, TypedDict
66

7-
from weave.trace.context.weave_client_context import require_weave_client
7+
from weave.trace.context.weave_client_context import (
8+
get_weave_client,
9+
require_weave_client,
10+
)
811
from weave.trace.op import op
912
from weave.trace.op_protocol import Op
1013
from weave.trace.refs import ObjectRef
@@ -65,6 +68,48 @@ class EncodedCustomObjDict(
6568
"weave.type_wrappers.Content.content.Content",
6669
}
6770

71+
# The one type whose serializer loads code (imports a user-uploaded `.py`).
72+
OP_CUSTOM_WEAVE_TYPE = "Op"
73+
74+
# Custom types whose registered serializer only reconstructs data (images, audio,
75+
# etc.). Everything else -- "Op" and any unknown type -- routes through a
76+
# code-loading path in `_decode_custom_obj`, so a client that forbids unsafe
77+
# decode (server-side workers) refuses them. Kept in sync with KNOWN_TYPES minus
78+
# "Op" by `test_safe_custom_weave_types_in_sync`; a new KNOWN_TYPE fails that test
79+
# until it is consciously classified here.
80+
SAFE_CUSTOM_WEAVE_TYPES = frozenset(
81+
{
82+
"PIL.Image.Image",
83+
"wave.Wave_read",
84+
"weave.type_handlers.Audio.audio.Audio",
85+
"datetime.datetime",
86+
"rich.markdown.Markdown",
87+
"moviepy.video.VideoClip.VideoClip",
88+
"weave.type_wrappers.Content.content.Content",
89+
}
90+
)
91+
92+
93+
class UnsafeDeserializationError(Exception):
94+
"""Raised when a client that forbids unsafe decode is asked to reconstruct a
95+
code-bearing custom object (`Op`, or any type that falls back to `load_op`).
96+
"""
97+
98+
99+
def is_safe_to_decode(
100+
weave_type_id: str, load_op_uri: str | None, *, allow_unsafe: bool
101+
) -> bool:
102+
"""Whether reconstructing this custom type is permitted.
103+
104+
`allow_unsafe` is the active client's policy. When False (server-side workers)
105+
only registered data-only serializers with no `load_op` may be reconstructed;
106+
a non-null `load_op` runs via the fallback path even for known types, so it is
107+
never safe.
108+
"""
109+
if allow_unsafe:
110+
return True
111+
return weave_type_id in SAFE_CUSTOM_WEAVE_TYPES and load_op_uri is None
112+
68113

69114
def encode_custom_obj(obj: Any) -> EncodedCustomObjDict | None:
70115
serializer = get_serializer_for_obj(obj)
@@ -193,6 +238,15 @@ def _decode_custom_obj(
193238
load_instance_op_uri: str | None = None,
194239
) -> Any:
195240
type_ = weave_type["type"]
241+
242+
client = get_weave_client()
243+
allow_unsafe = client is None or client.allow_unsafe_custom_obj_decode
244+
if not is_safe_to_decode(type_, load_instance_op_uri, allow_unsafe=allow_unsafe):
245+
raise UnsafeDeserializationError(
246+
f"Refusing to reconstruct custom object of type `{type_}`: the active "
247+
f"client does not allow deserializing code-bearing custom objects."
248+
)
249+
196250
found_serializer = False
197251

198252
# First, try to load the object using a known serializer

weave/trace/weave_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,10 @@ def __init__(
398398
self.future_executor_fastlane = FutureExecutor(max_workers=parallelism_upload)
399399
self.ensure_project_exists = ensure_project_exists
400400

401+
# Server-side workers (eval, scoring) flip this off so deserialization
402+
# never reconstructs code-bearing custom objects. See custom_objs.py.
403+
self.allow_unsafe_custom_obj_decode = True
404+
401405
if ensure_project_exists:
402406
resp = self.server.ensure_project_exists(entity, project)
403407
# Set Client project name with updated project name

weave/trace_server/validation.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -199,44 +199,3 @@ def validate_alias_name(name: str) -> None:
199199
)
200200
if name in _RESERVED_ALIAS_NAMES:
201201
raise ValueError(f"alias name '{name}' is reserved")
202-
203-
204-
# --- Server-side deserialization safety ---
205-
206-
207-
class UnsafePayloadError(ValueError):
208-
"""Raised when a payload contains types that are not safe for server-side deserialization."""
209-
210-
211-
def assert_safe_payload(value: Any, path: str = "root") -> None:
212-
"""Recursively walk a raw serialized payload and reject CustomWeaveType nodes.
213-
214-
CustomWeaveType payloads can trigger code execution during deserialization
215-
(e.g. Op types call __import__ on user-uploaded Python files, and unknown
216-
types fall back to a load_op path that does the same). None of these should
217-
be deserialized in a server-side worker process.
218-
219-
https://coreweave.atlassian.net/browse/VULNMGMT-1007
220-
"""
221-
if isinstance(value, list):
222-
for idx, item in enumerate(value):
223-
assert_safe_payload(item, f"{path}[{idx}]")
224-
return
225-
226-
if not isinstance(value, dict):
227-
return
228-
229-
if value.get("_type") == "CustomWeaveType":
230-
weave_type = value.get("weave_type")
231-
custom_type = (
232-
weave_type.get("type", "unknown")
233-
if isinstance(weave_type, dict)
234-
else "unknown"
235-
)
236-
raise UnsafePayloadError(
237-
f"Server-side deserialization does not allow CustomWeaveType payloads "
238-
f"({custom_type} at {path})"
239-
)
240-
241-
for key, item in value.items():
242-
assert_safe_payload(item, f"{path}.{key}")

0 commit comments

Comments
 (0)