|
23 | 23 |
|
24 | 24 | from openlayer import Openlayer, AsyncOpenlayer, APIResponseValidationError |
25 | 25 | from openlayer._types import Omit |
26 | | -from openlayer._utils import maybe_transform |
27 | 26 | from openlayer._models import BaseModel, FinalRequestOptions |
28 | | -from openlayer._constants import RAW_RESPONSE_HEADER |
29 | 27 | from openlayer._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError |
30 | 28 | from openlayer._base_client import ( |
31 | 29 | DEFAULT_TIMEOUT, |
|
35 | 33 | DefaultAsyncHttpxClient, |
36 | 34 | make_request_options, |
37 | 35 | ) |
38 | | -from openlayer.types.inference_pipelines.data_stream_params import DataStreamParams |
39 | 36 |
|
40 | 37 | from .utils import update_env |
41 | 38 |
|
@@ -724,82 +721,49 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str |
724 | 721 |
|
725 | 722 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
726 | 723 | @pytest.mark.respx(base_url=base_url) |
727 | | - def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 724 | + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Openlayer) -> None: |
728 | 725 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock( |
729 | 726 | side_effect=httpx.TimeoutException("Test timeout error") |
730 | 727 | ) |
731 | 728 |
|
732 | 729 | with pytest.raises(APITimeoutError): |
733 | | - self.client.post( |
734 | | - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
735 | | - body=cast( |
736 | | - object, |
737 | | - maybe_transform( |
738 | | - dict( |
739 | | - config={ |
740 | | - "input_variable_names": ["user_query"], |
741 | | - "output_column_name": "output", |
742 | | - "num_of_token_column_name": "tokens", |
743 | | - "cost_column_name": "cost", |
744 | | - "timestamp_column_name": "timestamp", |
745 | | - }, |
746 | | - rows=[ |
747 | | - { |
748 | | - "user_query": "what is the meaning of life?", |
749 | | - "output": "42", |
750 | | - "tokens": 7, |
751 | | - "cost": 0.02, |
752 | | - "timestamp": 1610000000, |
753 | | - } |
754 | | - ], |
755 | | - ), |
756 | | - DataStreamParams, |
757 | | - ), |
758 | | - ), |
759 | | - cast_to=httpx.Response, |
760 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
761 | | - ) |
| 730 | + client.inference_pipelines.data.with_streaming_response.stream( |
| 731 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 732 | + config={"output_column_name": "output"}, |
| 733 | + rows=[ |
| 734 | + { |
| 735 | + "user_query": "bar", |
| 736 | + "output": "bar", |
| 737 | + "tokens": "bar", |
| 738 | + "cost": "bar", |
| 739 | + "timestamp": "bar", |
| 740 | + } |
| 741 | + ], |
| 742 | + ).__enter__() |
762 | 743 |
|
763 | 744 | assert _get_open_connections(self.client) == 0 |
764 | 745 |
|
765 | 746 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
766 | 747 | @pytest.mark.respx(base_url=base_url) |
767 | | - def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 748 | + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Openlayer) -> None: |
768 | 749 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock( |
769 | 750 | return_value=httpx.Response(500) |
770 | 751 | ) |
771 | 752 |
|
772 | 753 | with pytest.raises(APIStatusError): |
773 | | - self.client.post( |
774 | | - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
775 | | - body=cast( |
776 | | - object, |
777 | | - maybe_transform( |
778 | | - dict( |
779 | | - config={ |
780 | | - "input_variable_names": ["user_query"], |
781 | | - "output_column_name": "output", |
782 | | - "num_of_token_column_name": "tokens", |
783 | | - "cost_column_name": "cost", |
784 | | - "timestamp_column_name": "timestamp", |
785 | | - }, |
786 | | - rows=[ |
787 | | - { |
788 | | - "user_query": "what is the meaning of life?", |
789 | | - "output": "42", |
790 | | - "tokens": 7, |
791 | | - "cost": 0.02, |
792 | | - "timestamp": 1610000000, |
793 | | - } |
794 | | - ], |
795 | | - ), |
796 | | - DataStreamParams, |
797 | | - ), |
798 | | - ), |
799 | | - cast_to=httpx.Response, |
800 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
801 | | - ) |
802 | | - |
| 754 | + client.inference_pipelines.data.with_streaming_response.stream( |
| 755 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 756 | + config={"output_column_name": "output"}, |
| 757 | + rows=[ |
| 758 | + { |
| 759 | + "user_query": "bar", |
| 760 | + "output": "bar", |
| 761 | + "tokens": "bar", |
| 762 | + "cost": "bar", |
| 763 | + "timestamp": "bar", |
| 764 | + } |
| 765 | + ], |
| 766 | + ).__enter__() |
803 | 767 | assert _get_open_connections(self.client) == 0 |
804 | 768 |
|
805 | 769 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) |
@@ -1652,82 +1616,53 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte |
1652 | 1616 |
|
1653 | 1617 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
1654 | 1618 | @pytest.mark.respx(base_url=base_url) |
1655 | | - async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1619 | + async def test_retrying_timeout_errors_doesnt_leak( |
| 1620 | + self, respx_mock: MockRouter, async_client: AsyncOpenlayer |
| 1621 | + ) -> None: |
1656 | 1622 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock( |
1657 | 1623 | side_effect=httpx.TimeoutException("Test timeout error") |
1658 | 1624 | ) |
1659 | 1625 |
|
1660 | 1626 | with pytest.raises(APITimeoutError): |
1661 | | - await self.client.post( |
1662 | | - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
1663 | | - body=cast( |
1664 | | - object, |
1665 | | - maybe_transform( |
1666 | | - dict( |
1667 | | - config={ |
1668 | | - "input_variable_names": ["user_query"], |
1669 | | - "output_column_name": "output", |
1670 | | - "num_of_token_column_name": "tokens", |
1671 | | - "cost_column_name": "cost", |
1672 | | - "timestamp_column_name": "timestamp", |
1673 | | - }, |
1674 | | - rows=[ |
1675 | | - { |
1676 | | - "user_query": "what is the meaning of life?", |
1677 | | - "output": "42", |
1678 | | - "tokens": 7, |
1679 | | - "cost": 0.02, |
1680 | | - "timestamp": 1610000000, |
1681 | | - } |
1682 | | - ], |
1683 | | - ), |
1684 | | - DataStreamParams, |
1685 | | - ), |
1686 | | - ), |
1687 | | - cast_to=httpx.Response, |
1688 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1689 | | - ) |
| 1627 | + await async_client.inference_pipelines.data.with_streaming_response.stream( |
| 1628 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 1629 | + config={"output_column_name": "output"}, |
| 1630 | + rows=[ |
| 1631 | + { |
| 1632 | + "user_query": "bar", |
| 1633 | + "output": "bar", |
| 1634 | + "tokens": "bar", |
| 1635 | + "cost": "bar", |
| 1636 | + "timestamp": "bar", |
| 1637 | + } |
| 1638 | + ], |
| 1639 | + ).__aenter__() |
1690 | 1640 |
|
1691 | 1641 | assert _get_open_connections(self.client) == 0 |
1692 | 1642 |
|
1693 | 1643 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
1694 | 1644 | @pytest.mark.respx(base_url=base_url) |
1695 | | - async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1645 | + async def test_retrying_status_errors_doesnt_leak( |
| 1646 | + self, respx_mock: MockRouter, async_client: AsyncOpenlayer |
| 1647 | + ) -> None: |
1696 | 1648 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock( |
1697 | 1649 | return_value=httpx.Response(500) |
1698 | 1650 | ) |
1699 | 1651 |
|
1700 | 1652 | with pytest.raises(APIStatusError): |
1701 | | - await self.client.post( |
1702 | | - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
1703 | | - body=cast( |
1704 | | - object, |
1705 | | - maybe_transform( |
1706 | | - dict( |
1707 | | - config={ |
1708 | | - "input_variable_names": ["user_query"], |
1709 | | - "output_column_name": "output", |
1710 | | - "num_of_token_column_name": "tokens", |
1711 | | - "cost_column_name": "cost", |
1712 | | - "timestamp_column_name": "timestamp", |
1713 | | - }, |
1714 | | - rows=[ |
1715 | | - { |
1716 | | - "user_query": "what is the meaning of life?", |
1717 | | - "output": "42", |
1718 | | - "tokens": 7, |
1719 | | - "cost": 0.02, |
1720 | | - "timestamp": 1610000000, |
1721 | | - } |
1722 | | - ], |
1723 | | - ), |
1724 | | - DataStreamParams, |
1725 | | - ), |
1726 | | - ), |
1727 | | - cast_to=httpx.Response, |
1728 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1729 | | - ) |
1730 | | - |
| 1653 | + await async_client.inference_pipelines.data.with_streaming_response.stream( |
| 1654 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 1655 | + config={"output_column_name": "output"}, |
| 1656 | + rows=[ |
| 1657 | + { |
| 1658 | + "user_query": "bar", |
| 1659 | + "output": "bar", |
| 1660 | + "tokens": "bar", |
| 1661 | + "cost": "bar", |
| 1662 | + "timestamp": "bar", |
| 1663 | + } |
| 1664 | + ], |
| 1665 | + ).__aenter__() |
1731 | 1666 | assert _get_open_connections(self.client) == 0 |
1732 | 1667 |
|
1733 | 1668 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) |
|
0 commit comments