Skip to content

Commit 6363299

Browse files
committed
added BaseModel__To__Type_Safe
1 parent e5f348b commit 6363299

2 files changed

Lines changed: 574 additions & 0 deletions

File tree

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
from typing import Type, Dict, Any, Optional, get_args, Union, List
2+
from osbot_utils.type_safe.decorators.type_safe import type_safe
3+
from osbot_utils.type_safe.shared.Type_Safe__Shared__Variables import IMMUTABLE_TYPES
4+
from pydantic import BaseModel
5+
from pydantic_core import PydanticUndefined
6+
from osbot_utils.type_safe.Type_Safe import Type_Safe
7+
from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache
8+
9+
10+
class BaseModel__To__Type_Safe(Type_Safe):
11+
class_cache: Dict[Type[BaseModel], Type[Type_Safe]] # Cache for mapped classes
12+
13+
@type_safe
14+
def convert_class(self, basemodel_class : Type[BaseModel] # BaseModel class to convert
15+
) -> Type[Type_Safe]: # Returns Type_Safe class
16+
if basemodel_class in self.class_cache: # Check cache first
17+
return self.class_cache[basemodel_class]
18+
19+
class_name = basemodel_class.__name__.replace('__BaseModel', '') # Generate class name
20+
if not class_name.endswith('__Type_Safe'): # Ensure proper naming
21+
class_name = f"{class_name}__Type_Safe"
22+
23+
annotations = {} # Build Type_Safe annotations
24+
defaults = {} # Build default values
25+
26+
for field_name, field_info in basemodel_class.model_fields.items(): # Process each field
27+
type_safe_type = self.convert_field_type(field_info.annotation) # Convert type annotation
28+
annotations[field_name] = type_safe_type
29+
30+
# Only add immutable defaults to class definition. Mutable defaults will be handled during instance creation
31+
if field_info.default is not PydanticUndefined: # Has an actual default
32+
if self.is_immutable_default(field_info.default): # Only immutable defaults
33+
defaults[field_name] = field_info.default
34+
elif field_info.default is None: # Explicit None is ok
35+
defaults[field_name] = None
36+
else:
37+
pass # Skip mutable defaults like lists, dicts, sets
38+
elif field_info.default_factory is not None: # Default factories
39+
pass # Don't add to class defaults, will handle in instance creation
40+
41+
type_safe_class = type(class_name , # Create Type_Safe class
42+
(Type_Safe,) , # Base class
43+
{**defaults , # Only immutable defaults
44+
'__annotations__': annotations})
45+
46+
self.class_cache[basemodel_class] = type_safe_class # Cache the result
47+
return type_safe_class
48+
49+
@type_safe
50+
def convert_instance(self, basemodel_instance : BaseModel # BaseModel instance to convert
51+
) -> Type_Safe: # Returns Type_Safe instance
52+
basemodel_class = type(basemodel_instance) # Get BaseModel class
53+
type_safe_class = self.convert_class(basemodel_class) # Get or create Type_Safe class
54+
55+
instance_data = self.extract_basemodel_data(basemodel_instance) # Extract data from BaseModel
56+
57+
constructor_args = {} # Separate data into constructor args and post-init assignments
58+
post_init_data = {}
59+
60+
for field_name, field_value in instance_data.items():
61+
if self.is_safe_for_constructor(field_value): # Simple types go to constructor
62+
constructor_args[field_name] = field_value
63+
else: # Complex types set after init
64+
post_init_data[field_name] = field_value
65+
66+
instance = type_safe_class(**constructor_args) # Create instance with safe constructor args
67+
68+
for field_name, field_value in post_init_data.items(): # Set complex fields after initialization
69+
setattr(instance, field_name, field_value)
70+
71+
return instance
72+
73+
def is_immutable_default(self, value : Any # Value to check
74+
) -> bool: # Returns True if immutable
75+
"""Check if a value is safe to use as a class-level default in Type_Safe."""
76+
if value is None:
77+
return True
78+
if type(value) in IMMUTABLE_TYPES:
79+
return True
80+
if isinstance(value, IMMUTABLE_TYPES):
81+
return True
82+
return False
83+
84+
def is_safe_for_constructor(self, value : Any # Value to check
85+
) -> bool: # Returns True if safe
86+
"""Check if a value is safe to pass to Type_Safe constructor."""
87+
if value is None:
88+
return True
89+
if isinstance(value, (str, int, float, bool, bytes)): # Primitives are safe
90+
return True
91+
if isinstance(value, Type_Safe): # Type_Safe instances ok
92+
return True
93+
return False # Lists, dicts, sets not safe
94+
95+
def convert_field_type(self, pydantic_type : Any # Pydantic type to convert
96+
) -> Any: # Returns Type_Safe compatible type
97+
origin = type_safe_cache.get_origin(pydantic_type)
98+
99+
if origin is list: # Handle List types
100+
args = get_args(pydantic_type)
101+
if args:
102+
inner_type = self.convert_field_type(args[0])
103+
return List[inner_type]
104+
return list
105+
106+
elif origin is dict: # Handle Dict types
107+
args = get_args(pydantic_type)
108+
if len(args) == 2:
109+
key_type = self.convert_field_type(args[0])
110+
value_type = self.convert_field_type(args[1])
111+
return Dict[key_type, value_type]
112+
return dict
113+
114+
elif origin is set: # Handle Set types
115+
args = get_args(pydantic_type)
116+
if args:
117+
inner_type = self.convert_field_type(args[0])
118+
from typing import Set
119+
return Set[inner_type]
120+
return set
121+
122+
elif origin in (Union, Optional): # Handle Union/Optional
123+
args = get_args(pydantic_type)
124+
converted_args = tuple(self.convert_field_type(arg) for arg in args)
125+
if origin is Optional:
126+
return Optional[converted_args[0]]
127+
return Union[converted_args]
128+
129+
if isinstance(pydantic_type, type) and issubclass(pydantic_type, BaseModel): # Handle nested BaseModel
130+
return self.convert_class(pydantic_type) # Recursively convert
131+
132+
return pydantic_type # Return standard types as-is
133+
134+
def extract_basemodel_data(self, basemodel_instance : BaseModel # Instance to extract from
135+
) -> Dict[str, Any]: # Returns extracted data
136+
data = basemodel_instance.model_dump() # Get all data as dict
137+
138+
result = {}
139+
for field_name, field_value in data.items():
140+
if field_value is None: # Skip None values
141+
result[field_name] = None
142+
elif isinstance(field_value, dict): # Handle dict fields
143+
result[field_name] = self.convert_dict_field(basemodel_instance, field_name, field_value)
144+
elif isinstance(field_value, list): # Handle list fields
145+
result[field_name] = self.convert_list_field(basemodel_instance, field_name, field_value)
146+
elif isinstance(field_value, set): # Handle set fields
147+
result[field_name] = self.convert_set_field(basemodel_instance, field_name, field_value)
148+
else:
149+
result[field_name] = self.convert_value(field_value) # Convert simple values
150+
151+
return result
152+
153+
def convert_dict_field(self, basemodel_instance : BaseModel , # Parent instance
154+
field_name : str , # Field name
155+
field_value : dict # Dict value to convert
156+
) -> Any: # Returns converted dict
157+
field_info = basemodel_instance.model_fields.get(field_name) # Get field metadata
158+
if not field_info:
159+
return field_value
160+
161+
origin = type_safe_cache.get_origin(field_info.annotation)
162+
if origin is dict: # It's a typed dict
163+
args = get_args(field_info.annotation)
164+
if len(args) == 2:
165+
key_type, value_type = args
166+
result = {}
167+
for k, v in field_value.items():
168+
converted_key = self.convert_value(k)
169+
converted_value = self.convert_nested_value(v, value_type)
170+
result[converted_key] = converted_value
171+
return result
172+
173+
return field_value # Return as-is if not typed
174+
175+
def convert_list_field(self, basemodel_instance : BaseModel , # Parent instance
176+
field_name : str , # Field name
177+
field_value : list # List value to convert
178+
) -> Any: # Returns converted list
179+
field_info = basemodel_instance.model_fields.get(field_name) # Get field metadata
180+
if not field_info:
181+
return field_value
182+
183+
origin = type_safe_cache.get_origin(field_info.annotation)
184+
if origin is list: # It's a typed list
185+
args = get_args(field_info.annotation)
186+
if args:
187+
item_type = args[0]
188+
result = []
189+
for item in field_value:
190+
converted_item = self.convert_nested_value(item, item_type)
191+
result.append(converted_item)
192+
return result
193+
194+
return field_value # Return as-is if not typed
195+
196+
def convert_set_field(self, basemodel_instance : BaseModel , # Parent instance
197+
field_name : str , # Field name
198+
field_value : set # Set value to convert
199+
) -> Any: # Returns converted set
200+
field_info = basemodel_instance.model_fields.get(field_name) # Get field metadata
201+
if not field_info:
202+
return field_value
203+
204+
origin = type_safe_cache.get_origin(field_info.annotation)
205+
if origin is set: # It's a typed set
206+
args = get_args(field_info.annotation)
207+
if args:
208+
item_type = args[0]
209+
result = set()
210+
for item in field_value:
211+
converted_item = self.convert_nested_value(item, item_type)
212+
result.add(converted_item)
213+
return result
214+
215+
return field_value # Return as-is if not typed
216+
217+
def convert_nested_value(self, value : Any , # Value to convert
218+
expected_type : Any # Expected type hint
219+
) -> Any: # Returns converted value
220+
if value is None:
221+
return None
222+
223+
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel): # Nested BaseModel
224+
if isinstance(value, dict): # Dict representation
225+
nested_model = expected_type(**value) # Create BaseModel instance
226+
return self.convert_instance(nested_model) # Convert to Type_Safe
227+
elif isinstance(value, BaseModel): # Already BaseModel
228+
return self.convert_instance(value)
229+
230+
if isinstance(value, dict): # Nested dict
231+
origin = type_safe_cache.get_origin(expected_type)
232+
if origin is dict:
233+
args = get_args(expected_type)
234+
if len(args) == 2:
235+
key_type, value_type = args
236+
result = {}
237+
for k, v in value.items():
238+
result[k] = self.convert_nested_value(v, value_type)
239+
return result
240+
241+
if isinstance(value, list): # Nested list
242+
origin = type_safe_cache.get_origin(expected_type)
243+
if origin is list:
244+
args = get_args(expected_type)
245+
if args:
246+
item_type = args[0]
247+
return [self.convert_nested_value(item, item_type) for item in value]
248+
249+
return self.convert_value(value) # Simple value
250+
251+
def convert_value(self, value : Any # Value to convert
252+
) -> Any: # Returns converted value
253+
if isinstance(value, BaseModel): # BaseModel instance
254+
return self.convert_instance(value)
255+
elif isinstance(value, dict): # Dict that might be BaseModel data
256+
if '_type' in value or '__class__' in value: # Looks like serialized object
257+
return value # Let Type_Safe handle it
258+
return value # Return simple values as-is
259+
260+
261+
basemodel__to__type_safe = BaseModel__To__Type_Safe() # Singleton instance for convenience

0 commit comments

Comments
 (0)