11import os
2- from typing import List
2+ from typing import List , Union
33
44import pytest
55from _pytest .fixtures import SubRequest
1212 Property ,
1313)
1414from weaviate .collections .classes .data import DataObject
15+ from weaviate .collections .classes .generative import (
16+ GenerativeConfig ,
17+ GenerativeParameters ,
18+ _GroupedTask ,
19+ _SinglePrompt ,
20+ )
1521from weaviate .collections .classes .grpc import GroupBy , Rerank
1622from weaviate .exceptions import WeaviateQueryError , WeaviateUnsupportedFeatureError
23+ from weaviate .proto .v1 .generative_pb2 import GenerativeOpenAIMetadata
1724from weaviate .util import _ServerVersion
1825
1926
@@ -340,8 +347,6 @@ def test_near_object_generate_with_everything(openai_collection: OpenAICollectio
340347 assert res .generated == "apples cats"
341348 assert res .objects [0 ].generated is not None
342349 assert res .objects [1 ].generated is not None
343- assert res .objects [0 ].generated .lower () == "yes"
344- assert res .objects [1 ].generated .lower () == "no"
345350
346351
347352def test_near_object_generate_and_group_by_with_everything (
@@ -355,7 +360,7 @@ def test_near_object_generate_and_group_by_with_everything(
355360 [
356361 DataObject (
357362 properties = {
358- "text" : "apples are big. you cna eat apples" ,
363+ "text" : "apples are big. you can eat apples" ,
359364 "content" : "Teddy is the biggest and bigger than everything else" ,
360365 }
361366 ),
@@ -380,8 +385,6 @@ def test_near_object_generate_and_group_by_with_everything(
380385 groups = list (res .groups .values ())
381386 assert groups [0 ].generated is not None
382387 assert groups [1 ].generated is not None
383- assert groups [0 ].generated .lower () == "no"
384- assert groups [1 ].generated .lower () == "yes"
385388
386389
387390def test_near_text_generate_with_everything (openai_collection : OpenAICollection ) -> None :
@@ -644,3 +647,98 @@ def test_queries_with_rerank_and_generative(collection_factory: CollectionFactor
644647 ][
645648 0
646649 ].metadata .rerank_score
650+
651+
652+ @pytest .mark .parametrize (
653+ "grouped" ,
654+ [
655+ "Write out the fruit in alphabetical order. Only write the names separated by a space" ,
656+ GenerativeParameters .grouped_task (
657+ prompt = "Write out the fruit in alphabetical order. Only write the names separated by a space" ,
658+ metadata = True ,
659+ ),
660+ ],
661+ ids = ["string" , "object" ],
662+ )
663+ @pytest .mark .parametrize (
664+ "single" ,
665+ [
666+ "Is there something to eat in {text} of the given object? Only answer yes if there is something to eat and no if not. Dont use punctuation" ,
667+ GenerativeParameters .single_prompt (
668+ prompt = "Is there something to eat in {text} of the given object? Only answer yes if there is something to eat and no if not. Dont use punctuation" ,
669+ metadata = True ,
670+ debug = True ,
671+ ),
672+ ],
673+ ids = ["string" , "object" ],
674+ )
675+ def test_near_text_generate_with_dynamic_rag (
676+ openai_collection : OpenAICollection ,
677+ grouped : Union [str , _GroupedTask ],
678+ single : Union [str , _SinglePrompt ],
679+ ) -> None :
680+ collection = openai_collection (
681+ vectorizer_config = Configure .Vectorizer .text2vec_openai (vectorize_collection_name = False ),
682+ )
683+
684+ collection .data .insert_many (
685+ [
686+ DataObject (
687+ properties = {
688+ "text" : "melons are big" ,
689+ "content" : "Teddy is the biggest and bigger than everything else. Teddy is not a fruit" ,
690+ }
691+ ),
692+ DataObject (
693+ properties = {
694+ "text" : "cats are small. You cannot eat cats. Cats are not fruit" ,
695+ "content" : "bananas are the smallest and smaller than everything else" ,
696+ }
697+ ),
698+ ]
699+ )
700+
701+ query = lambda : collection .generate .near_text (
702+ query = "small fruit" ,
703+ single_prompt = single ,
704+ grouped_task = grouped ,
705+ generative_provider = GenerativeConfig .openai (
706+ temperature = 0.1 ,
707+ ),
708+ )
709+
710+ if collection ._connection ._weaviate_version .is_lower_than (1 , 30 , 0 ):
711+ with pytest .raises (WeaviateUnsupportedFeatureError ):
712+ res = query ()
713+ else :
714+ res = query ()
715+ # deprecated usage
716+ assert res .generated == "bananas melons"
717+ assert res .objects [0 ].generated is not None
718+ assert res .objects [1 ].generated is not None
719+
720+ assert res .generative is not None
721+ assert res .generative .text == "bananas melons"
722+
723+ if isinstance (grouped , _GroupedTask ):
724+ assert isinstance (res .generative .metadata , GenerativeOpenAIMetadata )
725+ else :
726+ assert res .generative .metadata is None
727+
728+ g0 = res .objects [0 ].generative
729+ g1 = res .objects [1 ].generative
730+
731+ assert g0 is not None
732+ assert g0 .text is not None
733+ assert g1 is not None
734+ assert g1 .text is not None
735+
736+ if isinstance (single , _SinglePrompt ):
737+ assert g0 .debug is not None
738+ assert isinstance (g0 .metadata , GenerativeOpenAIMetadata )
739+ assert g1 .debug is not None
740+ assert isinstance (g1 .metadata , GenerativeOpenAIMetadata )
741+ else :
742+ assert g0 .debug is None
743+ assert g0 .metadata is None
744+ assert g1 .metadata is None
0 commit comments