|
34 | 34 | OpenAIChatGenerator, |
35 | 35 | _check_finish_reason, |
36 | 36 | _convert_chat_completion_chunk_to_streaming_chunk, |
| 37 | + _make_schema_strict, |
37 | 38 | ) |
38 | 39 | from haystack.components.generators.utils import print_streaming_chunk |
39 | 40 | from haystack.dataclasses import ( |
@@ -1871,3 +1872,302 @@ def test_convert_usage_chunk_to_streaming_chunk(self): |
1871 | 1872 | assert result.tool_call_result is None |
1872 | 1873 | assert result.meta["model"] == "gpt-5-mini" |
1873 | 1874 | assert result.meta["received_at"] is not None |
| 1875 | + |
| 1876 | + |
| 1877 | +class TestMakeSchemaStrict: |
| 1878 | + def test_flat_object(self): |
| 1879 | + schema = {"type": "object", "properties": {"name": {"type": "string"}}} |
| 1880 | + result = _make_schema_strict(schema) |
| 1881 | + assert result == { |
| 1882 | + "type": "object", |
| 1883 | + "properties": {"name": {"type": "string"}}, |
| 1884 | + "additionalProperties": False, |
| 1885 | + "required": ["name"], |
| 1886 | + } |
| 1887 | + |
| 1888 | + def test_nested_object(self): |
| 1889 | + schema = { |
| 1890 | + "type": "object", |
| 1891 | + "properties": { |
| 1892 | + "person": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}} |
| 1893 | + }, |
| 1894 | + } |
| 1895 | + result = _make_schema_strict(schema) |
| 1896 | + assert result == { |
| 1897 | + "type": "object", |
| 1898 | + "properties": { |
| 1899 | + "person": { |
| 1900 | + "type": "object", |
| 1901 | + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, |
| 1902 | + "additionalProperties": False, |
| 1903 | + "required": ["name", "age"], |
| 1904 | + } |
| 1905 | + }, |
| 1906 | + "additionalProperties": False, |
| 1907 | + "required": ["person"], |
| 1908 | + } |
| 1909 | + |
| 1910 | + def test_defs_and_ref(self): |
| 1911 | + schema = { |
| 1912 | + "type": "object", |
| 1913 | + "properties": {"address": {"$ref": "#/$defs/Address"}}, |
| 1914 | + "$defs": { |
| 1915 | + "Address": {"type": "object", "properties": {"street": {"type": "string"}, "city": {"type": "string"}}} |
| 1916 | + }, |
| 1917 | + } |
| 1918 | + result = _make_schema_strict(schema) |
| 1919 | + assert result == { |
| 1920 | + "type": "object", |
| 1921 | + "properties": {"address": {"$ref": "#/$defs/Address"}}, |
| 1922 | + "$defs": { |
| 1923 | + "Address": { |
| 1924 | + "type": "object", |
| 1925 | + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, |
| 1926 | + "additionalProperties": False, |
| 1927 | + "required": ["street", "city"], |
| 1928 | + } |
| 1929 | + }, |
| 1930 | + "additionalProperties": False, |
| 1931 | + "required": ["address"], |
| 1932 | + } |
| 1933 | + |
| 1934 | + def test_array_items(self): |
| 1935 | + schema = { |
| 1936 | + "type": "object", |
| 1937 | + "properties": { |
| 1938 | + "people": {"type": "array", "items": {"type": "object", "properties": {"name": {"type": "string"}}}} |
| 1939 | + }, |
| 1940 | + } |
| 1941 | + result = _make_schema_strict(schema) |
| 1942 | + assert result == { |
| 1943 | + "type": "object", |
| 1944 | + "properties": { |
| 1945 | + "people": { |
| 1946 | + "type": "array", |
| 1947 | + "items": { |
| 1948 | + "type": "object", |
| 1949 | + "properties": {"name": {"type": "string"}}, |
| 1950 | + "additionalProperties": False, |
| 1951 | + "required": ["name"], |
| 1952 | + }, |
| 1953 | + } |
| 1954 | + }, |
| 1955 | + "additionalProperties": False, |
| 1956 | + "required": ["people"], |
| 1957 | + } |
| 1958 | + |
| 1959 | + def test_anyof(self): |
| 1960 | + schema = { |
| 1961 | + "type": "object", |
| 1962 | + "properties": { |
| 1963 | + "value": {"anyOf": [{"type": "string"}, {"type": "object", "properties": {"x": {"type": "integer"}}}]} |
| 1964 | + }, |
| 1965 | + } |
| 1966 | + result = _make_schema_strict(schema) |
| 1967 | + assert result == { |
| 1968 | + "type": "object", |
| 1969 | + "properties": { |
| 1970 | + "value": { |
| 1971 | + "anyOf": [ |
| 1972 | + {"type": "string"}, |
| 1973 | + { |
| 1974 | + "type": "object", |
| 1975 | + "properties": {"x": {"type": "integer"}}, |
| 1976 | + "additionalProperties": False, |
| 1977 | + "required": ["x"], |
| 1978 | + }, |
| 1979 | + ] |
| 1980 | + } |
| 1981 | + }, |
| 1982 | + "additionalProperties": False, |
| 1983 | + "required": ["value"], |
| 1984 | + } |
| 1985 | + |
| 1986 | + def test_does_not_mutate_original(self): |
| 1987 | + schema = {"type": "object", "properties": {"a": {"type": "string"}}} |
| 1988 | + result = _make_schema_strict(schema) |
| 1989 | + assert "additionalProperties" not in schema |
| 1990 | + assert "required" not in schema |
| 1991 | + assert result == { |
| 1992 | + "type": "object", |
| 1993 | + "properties": {"a": {"type": "string"}}, |
| 1994 | + "additionalProperties": False, |
| 1995 | + "required": ["a"], |
| 1996 | + } |
| 1997 | + |
| 1998 | + def test_preserves_existing_required(self): |
| 1999 | + schema = { |
| 2000 | + "type": "object", |
| 2001 | + "properties": {"a": {"type": "string"}, "b": {"type": "integer"}}, |
| 2002 | + "required": ["a"], |
| 2003 | + } |
| 2004 | + result = _make_schema_strict(schema) |
| 2005 | + assert result == { |
| 2006 | + "type": "object", |
| 2007 | + "properties": {"a": {"type": "string"}, "b": {"type": "integer"}}, |
| 2008 | + "additionalProperties": False, |
| 2009 | + "required": ["a", "b"], |
| 2010 | + } |
| 2011 | + |
| 2012 | + def test_complex_schema_with_defs_and_combinators(self): |
| 2013 | + schema = { |
| 2014 | + "type": "object", |
| 2015 | + "properties": { |
| 2016 | + "messages": {"type": "array", "items": {"$ref": "#/$defs/ChatMessage"}}, |
| 2017 | + "config": { |
| 2018 | + "oneOf": [ |
| 2019 | + {"type": "null"}, |
| 2020 | + { |
| 2021 | + "type": "object", |
| 2022 | + "properties": {"temperature": {"type": "number"}, "max_tokens": {"type": "integer"}}, |
| 2023 | + }, |
| 2024 | + ] |
| 2025 | + }, |
| 2026 | + }, |
| 2027 | + "$defs": { |
| 2028 | + "ChatMessage": { |
| 2029 | + "type": "object", |
| 2030 | + "properties": { |
| 2031 | + "role": {"type": "string"}, |
| 2032 | + "content": {"anyOf": [{"type": "string"}, {"type": "null"}]}, |
| 2033 | + "meta": { |
| 2034 | + "type": "object", |
| 2035 | + "properties": { |
| 2036 | + "model": {"type": "string"}, |
| 2037 | + "usage": { |
| 2038 | + "type": "object", |
| 2039 | + "properties": { |
| 2040 | + "prompt_tokens": {"type": "integer"}, |
| 2041 | + "completion_tokens": {"type": "integer"}, |
| 2042 | + }, |
| 2043 | + }, |
| 2044 | + }, |
| 2045 | + }, |
| 2046 | + }, |
| 2047 | + } |
| 2048 | + }, |
| 2049 | + } |
| 2050 | + result = _make_schema_strict(schema) |
| 2051 | + assert result == { |
| 2052 | + "type": "object", |
| 2053 | + "properties": { |
| 2054 | + "messages": {"type": "array", "items": {"$ref": "#/$defs/ChatMessage"}}, |
| 2055 | + "config": { |
| 2056 | + "oneOf": [ |
| 2057 | + {"type": "null"}, |
| 2058 | + { |
| 2059 | + "type": "object", |
| 2060 | + "properties": {"temperature": {"type": "number"}, "max_tokens": {"type": "integer"}}, |
| 2061 | + "additionalProperties": False, |
| 2062 | + "required": ["temperature", "max_tokens"], |
| 2063 | + }, |
| 2064 | + ] |
| 2065 | + }, |
| 2066 | + }, |
| 2067 | + "$defs": { |
| 2068 | + "ChatMessage": { |
| 2069 | + "type": "object", |
| 2070 | + "properties": { |
| 2071 | + "role": {"type": "string"}, |
| 2072 | + "content": {"anyOf": [{"type": "string"}, {"type": "null"}]}, |
| 2073 | + "meta": { |
| 2074 | + "type": "object", |
| 2075 | + "properties": { |
| 2076 | + "model": {"type": "string"}, |
| 2077 | + "usage": { |
| 2078 | + "type": "object", |
| 2079 | + "properties": { |
| 2080 | + "prompt_tokens": {"type": "integer"}, |
| 2081 | + "completion_tokens": {"type": "integer"}, |
| 2082 | + }, |
| 2083 | + "additionalProperties": False, |
| 2084 | + "required": ["prompt_tokens", "completion_tokens"], |
| 2085 | + }, |
| 2086 | + }, |
| 2087 | + "additionalProperties": False, |
| 2088 | + "required": ["model", "usage"], |
| 2089 | + }, |
| 2090 | + }, |
| 2091 | + "additionalProperties": False, |
| 2092 | + "required": ["role", "content", "meta"], |
| 2093 | + } |
| 2094 | + }, |
| 2095 | + "additionalProperties": False, |
| 2096 | + "required": ["messages", "config"], |
| 2097 | + } |
| 2098 | + |
| 2099 | + def test_prepare_api_call_strict_nested_tool(self): |
| 2100 | + nested_tool = Tool( |
| 2101 | + name="create_person", |
| 2102 | + description="Create a person record", |
| 2103 | + parameters={ |
| 2104 | + "type": "object", |
| 2105 | + "properties": { |
| 2106 | + "name": {"type": "string"}, |
| 2107 | + "address": { |
| 2108 | + "type": "object", |
| 2109 | + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, |
| 2110 | + }, |
| 2111 | + }, |
| 2112 | + "required": ["name"], |
| 2113 | + }, |
| 2114 | + function=lambda name, address: f"{name} at {address}", |
| 2115 | + ) |
| 2116 | + |
| 2117 | + component = OpenAIChatGenerator(api_key=Secret.from_token("test-key"), tools_strict=True) |
| 2118 | + api_args = component._prepare_api_call(messages=[ChatMessage.from_user("test")], tools=[nested_tool]) |
| 2119 | + |
| 2120 | + tool_def = api_args["tools"][0]["function"] |
| 2121 | + assert tool_def["strict"] is True |
| 2122 | + assert tool_def["parameters"] == { |
| 2123 | + "type": "object", |
| 2124 | + "properties": { |
| 2125 | + "name": {"type": "string"}, |
| 2126 | + "address": { |
| 2127 | + "type": "object", |
| 2128 | + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, |
| 2129 | + "additionalProperties": False, |
| 2130 | + "required": ["street", "city"], |
| 2131 | + }, |
| 2132 | + }, |
| 2133 | + "additionalProperties": False, |
| 2134 | + "required": ["name", "address"], |
| 2135 | + } |
| 2136 | + |
| 2137 | + @pytest.mark.skipif( |
| 2138 | + not os.environ.get("OPENAI_API_KEY", None), |
| 2139 | + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", |
| 2140 | + ) |
| 2141 | + @pytest.mark.integration |
| 2142 | + def test_live_run_strict_nested_tool(self): |
| 2143 | + tool = Tool( |
| 2144 | + name="create_person", |
| 2145 | + description="Create a person record with an address", |
| 2146 | + parameters={ |
| 2147 | + "type": "object", |
| 2148 | + "properties": { |
| 2149 | + "name": {"type": "string", "description": "Full name"}, |
| 2150 | + "address": { |
| 2151 | + "type": "object", |
| 2152 | + "properties": { |
| 2153 | + "street": {"type": "string", "description": "Street address"}, |
| 2154 | + "city": {"type": "string", "description": "City name"}, |
| 2155 | + }, |
| 2156 | + }, |
| 2157 | + }, |
| 2158 | + }, |
| 2159 | + function=lambda name, address: f"{name} at {address}", |
| 2160 | + ) |
| 2161 | + component = OpenAIChatGenerator(model="gpt-4.1-nano", tools_strict=True) |
| 2162 | + results = component.run( |
| 2163 | + messages=[ChatMessage.from_user("Create a person named John at 123 Main St, Springfield")], tools=[tool] |
| 2164 | + ) |
| 2165 | + assert len(results["replies"]) == 1 |
| 2166 | + message = results["replies"][0] |
| 2167 | + assert message.tool_calls |
| 2168 | + tool_call = message.tool_call |
| 2169 | + assert tool_call.tool_name == "create_person" |
| 2170 | + assert "name" in tool_call.arguments |
| 2171 | + assert "address" in tool_call.arguments |
| 2172 | + assert "street" in tool_call.arguments["address"] |
| 2173 | + assert "city" in tool_call.arguments["address"] |
0 commit comments