|
26 | 26 | GenerativeNvidia, |
27 | 27 | GenerativeOllama, |
28 | 28 | GenerativeOpenAI, |
| 29 | + GenerativeXAI, |
29 | 30 | GenerativeProvider as GenerativeProviderGRPC, |
30 | 31 | GenerativeSearch, |
31 | 32 | ) |
@@ -398,6 +399,31 @@ def _to_grpc(self, opts: _GenerativeConfigRuntimeOptions) -> GenerativeProviderG |
398 | 399 | ) |
399 | 400 |
|
400 | 401 |
|
class _GenerativeXAI(_GenerativeConfigRuntime):
    """Runtime (per-request) configuration for the `generative-xai` module."""

    # Discriminator selecting the xAI provider; frozen so it cannot be
    # overridden and excluded so it is never serialized as an option.
    generative: Union[GenerativeSearches, _EnumLikeStr] = Field(
        default=GenerativeSearches.XAI, frozen=True, exclude=True
    )
    base_url: Optional[AnyHttpUrl]
    max_tokens: Optional[int]
    model: Optional[str]
    temperature: Optional[float]
    top_p: Optional[float]

    def _to_grpc(self, opts: _GenerativeConfigRuntimeOptions) -> GenerativeProviderGRPC:
        """Build the gRPC provider message carrying these xAI options."""
        provider = GenerativeXAI(
            base_url=_parse_anyhttpurl(self.base_url),
            max_tokens=self.max_tokens,
            model=self.model,
            temperature=self.temperature,
            top_p=self.top_p,
            images=_to_text_array(opts.images),
            image_properties=_to_text_array(opts.image_properties),
        )
        return GenerativeProviderGRPC(
            return_metadata=opts.return_metadata,
            xai=provider,
        )
401 | 427 | class GenerativeConfig: |
402 | 428 | """Use this factory class to create the correct object for the `generative_provider` argument in the search methods of the `.generate` namespace. |
403 | 429 |
|
@@ -918,6 +944,40 @@ def openai_azure( |
918 | 944 | is_azure=True, |
919 | 945 | ) |
920 | 946 |
|
| 947 | + @staticmethod |
| 948 | + def xai( |
| 949 | + *, |
| 950 | + base_url: Optional[str] = None, |
| 951 | + max_tokens: Optional[int] = None, |
| 952 | + model: Optional[str] = None, |
| 953 | + temperature: Optional[float] = None, |
| 954 | + top_p: Optional[float] = None, |
| 955 | + ) -> _GenerativeConfigRuntime: |
| 956 | + """Create a `_GenerativeXAI` object for use when performing AI generation using the `generative-xai` module. |
| 957 | +
|
| 958 | + See the [documentation](https://weaviate.io/developers/weaviate/modules/reader-generator-modules/generative-xai) |
| 959 | + for detailed usage. |
| 960 | +
|
| 961 | + Arguments: |
| 962 | + `base_url` |
| 963 | + The base URL where the API request should go. Defaults to `None`, which uses the server-defined default |
| 964 | + `max_tokens` |
| 965 | + The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default |
| 966 | + `model` |
| 967 | + The model to use. Defaults to `None`, which uses the server-defined default |
| 968 | + `temperature` |
| 969 | + The temperature to use. Defaults to `None`, which uses the server-defined default |
| 970 | + `top_p` |
| 971 | + The top P to use. Defaults to `None`, which uses the server-defined default |
| 972 | + """ |
| 973 | + return _GenerativeXAI( |
| 974 | + base_url=AnyUrl(base_url) if base_url is not None else None, |
| 975 | + max_tokens=max_tokens, |
| 976 | + model=model, |
| 977 | + temperature=temperature, |
| 978 | + top_p=top_p, |
| 979 | + ) |
| 980 | + |
921 | 981 |
|
922 | 982 | class _GroupedTask(BaseModel): |
923 | 983 | prompt: str |
@@ -1006,7 +1066,7 @@ def single_prompt( |
1006 | 1066 |
|
1007 | 1067 | @staticmethod |
1008 | 1068 | def __parse_images( |
1009 | | - images: Optional[Union[BLOB_INPUT, Iterable[BLOB_INPUT]]] |
| 1069 | + images: Optional[Union[BLOB_INPUT, Iterable[BLOB_INPUT]]], |
1010 | 1070 | ) -> Optional[Iterable[str]]: |
1011 | 1071 | if isinstance(images, (str, Path, BufferedReader)): |
1012 | 1072 | return ( |
|
0 commit comments