|
34 | 34 | T = TypeVar("T", bound=BaseModel) |
35 | 35 |
|
36 | 36 |
|
37 | | -def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T: |
38 | | - """Creates a model from a response.""" |
39 | | - model_field_names = model_type.model_fields.keys() |
40 | | - filtered_response = {} |
41 | | - for key, value in response.items(): |
42 | | - snake_key = common.camel_to_snake(key) |
43 | | - if snake_key in model_field_names: |
44 | | - filtered_response[snake_key] = value |
45 | | - return model_type(**filtered_response) |
| 37 | +def create_from_response( |
| 38 | + model_type: Type[T], |
| 39 | + response: dict[str, Any], |
| 40 | + config: Any | None = None, |
| 41 | +) -> T: |
| 42 | + """Creates a model from a response.""" |
| 43 | + kwargs = ( |
| 44 | + { |
| 45 | + "config": { |
| 46 | + "response_schema": getattr(config, "response_schema", None), |
| 47 | + "response_json_schema": getattr( |
| 48 | + config, "response_json_schema", None |
| 49 | + ), |
| 50 | + "include_all_fields": getattr(config, "include_all_fields", None), |
| 51 | + } |
| 52 | + } |
| 53 | + if config |
| 54 | + else {} |
| 55 | + ) |
| 56 | + return model_type._from_response(response=response, kwargs=kwargs) # type: ignore[attr-defined,no-any-return] |
46 | 57 |
|
47 | 58 |
|
48 | 59 | def validate_multimodal_dataset_bigquery_uri( |
|
0 commit comments