1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License
14+ import mock
1415
1516from google .cloud .firestore_v1 .pipeline_source import PipelineSource
1617from google .cloud .firestore_v1 .pipeline import Pipeline
1920from google .cloud .firestore_v1 .async_client import AsyncClient
2021from google .cloud .firestore_v1 import pipeline_stages as stages
2122from google .cloud .firestore_v1 .base_document import BaseDocumentReference
23+ from google .cloud .firestore_v1 .query import Query
24+ from google .cloud .firestore_v1 .async_query import AsyncQuery
2225
2326
2427class TestPipelineSource :
@@ -27,6 +30,9 @@ class TestPipelineSource:
2730 def _make_client (self ):
2831 return Client ()
2932
33+ def _make_query (self ):
34+ return Query (mock .Mock ())
35+
3036 def test_make_from_client (self ):
3137 instance = self ._make_client ().pipeline ()
3238 assert isinstance (instance , PipelineSource )
@@ -36,6 +42,23 @@ def test_create_pipeline(self):
3642 ppl = instance ._create_pipeline (None )
3743 assert isinstance (ppl , self ._expected_pipeline_type )
3844
45+ def test_create_from_mock (self ):
46+ mock_query = mock .Mock ()
47+ expected = object ()
48+ mock_query ._build_pipeline .return_value = expected
49+ instance = self ._make_client ().pipeline ()
50+ got = instance .create_from (mock_query )
51+ assert got == expected
52+ assert mock_query ._build_pipeline .call_count == 1
53+ assert mock_query ._build_pipeline .call_args_list [0 ][0 ][0 ] == instance
54+
55+ def test_create_from_query (self ):
56+ query = self ._make_query ()
57+ instance = self ._make_client ().pipeline ()
58+ ppl = instance .create_from (query )
59+ assert isinstance (ppl , self ._expected_pipeline_type )
60+ assert len (ppl .stages ) == 1
61+
3962 def test_collection (self ):
4063 instance = self ._make_client ().pipeline ()
4164 ppl = instance .collection ("path" )
@@ -98,3 +121,6 @@ class TestPipelineSourceWithAsyncClient(TestPipelineSource):
98121
99122 def _make_client (self ):
100123 return AsyncClient ()
124+
125+ def _make_query (self ):
126+ return AsyncQuery (mock .Mock ())
0 commit comments