|
1 | 1 | import pytest |
2 | 2 |
|
3 | | -from weave.trace_server.validation import ( |
4 | | - UnsafePayloadError, |
5 | | - assert_safe_payload, |
| 3 | +from weave.trace.serialization.custom_objs import ( |
| 4 | + KNOWN_TYPES, |
| 5 | + SAFE_CUSTOM_WEAVE_TYPES, |
| 6 | + is_safe_to_decode, |
| 7 | +) |
| 8 | +from weave.trace_server.workers.evaluate_model_worker.evaluate_model_worker import ( |
| 9 | + _assert_object_ref, |
6 | 10 | ) |
7 | 11 |
|
8 | 12 |
|
9 | | -def test_assert_safe_payload(): |
10 | | - # Safe payloads pass |
11 | | - assert_safe_payload({"name": "test", "value": 42}) |
12 | | - assert_safe_payload({"nested": {"list": [1, "two", {"three": 3}]}}) |
13 | | - assert_safe_payload("just a string ref") |
14 | | - assert_safe_payload(None) |
15 | | - assert_safe_payload([1, 2, 3]) |
16 | | - assert_safe_payload({"_type": "ObjectRecord", "name": "test"}) |
17 | | - |
18 | | - # ALL CustomWeaveType payloads rejected — Op, unknown, and safe subtypes alike |
19 | | - with pytest.raises(UnsafePayloadError): |
20 | | - assert_safe_payload( |
21 | | - { |
22 | | - "_type": "CustomWeaveType", |
23 | | - "weave_type": {"type": "Op"}, |
24 | | - "files": {"obj.py": "abc123"}, |
25 | | - } |
26 | | - ) |
27 | | - |
28 | | - with pytest.raises(UnsafePayloadError): |
29 | | - assert_safe_payload( |
30 | | - { |
31 | | - "_type": "CustomWeaveType", |
32 | | - "weave_type": {"type": "PIL.Image.Image"}, |
33 | | - "files": {"image.png": "abc123"}, |
34 | | - } |
35 | | - ) |
| 13 | +@pytest.mark.parametrize( |
| 14 | + ("type_id", "load_op", "allow_unsafe", "expected"), |
| 15 | + [ |
| 16 | + # allow_unsafe (normal client) -> anything decodes. |
| 17 | + ("Op", None, True, True), |
| 18 | + ("TotallyMadeUp", None, True, True), |
| 19 | + # Worker client (allow_unsafe=False): only data-only serializers, no load_op. |
| 20 | + ("Op", None, False, False), |
| 21 | + ("TotallyMadeUp", None, False, False), |
| 22 | + ("PIL.Image.Image", None, False, True), |
| 23 | + ("weave.type_wrappers.Content.content.Content", None, False, True), |
| 24 | + # A load_op routes through the fallback code path even for known types. |
| 25 | + ("PIL.Image.Image", "weave:///e/p/op/x:1", False, False), |
| 26 | + ], |
| 27 | +) |
| 28 | +def test_is_safe_to_decode(type_id, load_op, allow_unsafe, expected): |
| 29 | + assert is_safe_to_decode(type_id, load_op, allow_unsafe=allow_unsafe) is expected |
36 | 30 |
|
37 | | - with pytest.raises(UnsafePayloadError): |
38 | | - assert_safe_payload( |
39 | | - { |
40 | | - "_type": "CustomWeaveType", |
41 | | - "weave_type": {"type": "TotallyMadeUpType"}, |
42 | | - "load_op": "weave:///entity/project/object/evil:abc123", |
43 | | - } |
44 | | - ) |
45 | 31 |
|
46 | | - # Nested in a list |
47 | | - with pytest.raises(UnsafePayloadError): |
48 | | - assert_safe_payload( |
49 | | - { |
50 | | - "scorers": [ |
51 | | - { |
52 | | - "_type": "CustomWeaveType", |
53 | | - "weave_type": {"type": "Op"}, |
54 | | - } |
55 | | - ], |
56 | | - } |
57 | | - ) |
| 32 | +def test_safe_custom_weave_types_in_sync(): |
| 33 | + # Every known custom type must be classified: safe data-only serializers go in |
| 34 | + # SAFE_CUSTOM_WEAVE_TYPES, and "Op" is the lone code-loading type. A newly |
| 35 | + # added KNOWN_TYPE fails here until it is consciously placed on one side. |
| 36 | + assert SAFE_CUSTOM_WEAVE_TYPES | {"Op"} == set(KNOWN_TYPES) |
58 | 37 |
|
59 | | - # Deeply nested in dicts |
60 | | - with pytest.raises(UnsafePayloadError): |
61 | | - assert_safe_payload( |
62 | | - { |
63 | | - "a": { |
64 | | - "b": { |
65 | | - "c": { |
66 | | - "_type": "CustomWeaveType", |
67 | | - "weave_type": {"type": "Op"}, |
68 | | - } |
69 | | - } |
70 | | - }, |
71 | | - } |
72 | | - ) |
73 | 38 |
|
74 | | - # Missing/malformed weave_type still rejected |
75 | | - with pytest.raises(UnsafePayloadError): |
76 | | - assert_safe_payload({"_type": "CustomWeaveType"}) |
77 | | - with pytest.raises(UnsafePayloadError): |
78 | | - assert_safe_payload({"_type": "CustomWeaveType", "weave_type": "not a dict"}) |
| 39 | +def test_assert_object_ref_rejects_op_and_non_object_refs(): |
| 40 | + # Op ref would be loaded/executed by client.get; table/other refs aren't models. |
| 41 | + with pytest.raises(TypeError): |
| 42 | + _assert_object_ref("weave:///ent/proj/op/some_op:abc123", "evaluation_ref") |
| 43 | + with pytest.raises(TypeError): |
| 44 | + _assert_object_ref("weave:///ent/proj/table/abc123", "evaluation_ref") |
| 45 | + # A plain object ref is accepted. |
| 46 | + _assert_object_ref("weave:///ent/proj/object/MyEval:abc123", "evaluation_ref") |
0 commit comments