1414from functools import total_ordering , wraps
1515from math import ceil
1616from pathlib import Path
17- from typing import Any , Callable , Literal , Optional , TypeVar , Union
17+ from typing import Any , Callable , Literal , Optional , TypeVar , Union , get_args , get_origin
1818
1919import h5py
2020import numpy as np
@@ -189,6 +189,131 @@ def _hash_self(self) -> str:
189189 self .to_hdf5 (bf )
190190 return hashlib .sha256 (bf .getvalue ()).hexdigest ()
191191
192+ def _preprocess_update_values (self , update : Mapping [str , Any ]) -> dict [str , Any ]:
193+ """Preprocess update values to convert lists to tuples where appropriate.
194+
195+ This helps avoid Pydantic v2 serialization warnings when using `model_copy()`
196+ with list values for tuple fields.
197+ """
198+ if not update :
199+ return {}
200+
201+ def get_tuple_element_type (annotation ) -> Optional [type ]:
202+ """Get the element type of a tuple annotation if it has one consistent type."""
203+ origin = get_origin (annotation )
204+ if origin is tuple :
205+ args = get_args (annotation )
206+ if args :
207+ # Check if it's a homogeneous tuple like tuple[bool, ...] or tuple[str, ...]
208+ if len (args ) == 2 and args [1 ] is ...:
209+ return args [0 ]
210+ # Check if all elements have the same type
211+ if all (arg == args [0 ] for arg in args ):
212+ return args [0 ]
213+ return None
214+
215+ def should_convert_to_tuple (annotation ) -> tuple [bool , Optional [type ]]:
216+ """Check if the given annotation represents a tuple type and return element type if any."""
217+ origin = get_origin (annotation )
218+
219+ if origin is tuple :
220+ return True , get_tuple_element_type (annotation )
221+
222+ # Union types containing tuple
223+ if origin is Union :
224+ args = get_args (annotation )
225+ for arg in args :
226+ if get_origin (arg ) is tuple :
227+ return True , get_tuple_element_type (arg )
228+
229+ return False , None
230+
231+ def convert_value (value : Any , field_info ) -> Any :
232+ """Convert value based on field type information."""
233+ annotation = field_info .annotation
234+
235+ # Handle list/tuple to tuple conversion with proper element types
236+ is_tuple , element_type = should_convert_to_tuple (annotation )
237+
238+ # Check if value is a numpy array and needs to be converted to tuple
239+ try :
240+ import numpy as np
241+
242+ if isinstance (value , np .ndarray ) and is_tuple :
243+ # Convert numpy array to list first
244+ value = value .tolist ()
245+ except ImportError :
246+ pass
247+
248+ # Handle autograd SequenceBox - convert to tuple
249+ if (
250+ is_tuple
251+ and hasattr (value , "__class__" )
252+ and value .__class__ .__name__ == "SequenceBox"
253+ ):
254+ # SequenceBox is iterable, so convert it to tuple
255+ return tuple (value )
256+
257+ if isinstance (value , (list , tuple )) and is_tuple :
258+ # Convert elements based on element type
259+ if element_type is bool :
260+ # Convert integers to booleans
261+ value = [bool (item ) if isinstance (item , int ) else item for item in value ]
262+ elif element_type is str :
263+ # Ensure all elements are strings
264+ value = [str (item ) if not isinstance (item , str ) else item for item in value ]
265+ else :
266+ # Check if it's a numpy array or contains numpy types
267+ try :
268+ import numpy as np
269+
270+ if any (isinstance (item , np .generic ) for item in value ):
271+ # Convert numpy types to Python types
272+ value = [
273+ item .item () if isinstance (item , np .generic ) else item
274+ for item in value
275+ ]
276+ except ImportError :
277+ pass
278+ return tuple (value )
279+
280+ # Handle int to bool conversion
281+ if annotation is bool and isinstance (value , int ):
282+ return bool (value )
283+
284+ # Handle dict to Tidy3dBaseModel conversion
285+ if isinstance (value , dict ):
286+ # Check if the annotation is a Tidy3dBaseModel subclass
287+ origin = get_origin (annotation )
288+ if origin is None :
289+ # Not a generic type, check if it's a direct subclass
290+ try :
291+ if isinstance (annotation , type ) and issubclass (annotation , Tidy3dBaseModel ):
292+ return annotation (** value )
293+ except (TypeError , AttributeError ):
294+ pass
295+ elif origin is Union :
296+ # For Union types, try to convert to the first matching Tidy3dBaseModel type
297+ args = get_args (annotation )
298+ for arg in args :
299+ try :
300+ if isinstance (arg , type ) and issubclass (arg , Tidy3dBaseModel ):
301+ return arg (** value )
302+ except (TypeError , AttributeError , ValueError ):
303+ continue
304+
305+ return value
306+
307+ processed = {}
308+ for field_name , value in update .items ():
309+ if field_name in self .model_fields :
310+ field_info = self .model_fields [field_name ]
311+ processed [field_name ] = convert_value (value , field_info )
312+ else :
313+ processed [field_name ] = value
314+
315+ return processed
316+
192317 def copy (
193318 self , * , deep : bool = True , validate : bool = True , update : Mapping [str , Any ] | None = None
194319 ) -> Self :
@@ -209,6 +334,10 @@ def copy(
209334 if invalid :
210335 raise KeyError (f"'{ self .type } ' received invalid fields on copy: { invalid } " )
211336
337+ # preprocess update values to convert lists to tuples where appropriate
338+ if update :
339+ update = self ._preprocess_update_values (update )
340+
212341 new_model = self .model_copy (deep = deep , update = update )
213342
214343 if validate :
@@ -275,8 +404,8 @@ def updated_copy(
275404
276405 @staticmethod
277406 def _core_model_traversal (
278- current_obj : Any , current_path_segments : list [str ]
279- ) -> Iterator [tuple [Self , list [str ]]]:
407+ current_obj : Any , current_path_segments : tuple [str , ... ]
408+ ) -> Iterator [tuple [Self , tuple [str , ... ]]]:
280409 """
281410 Recursively traverses a model structure yielding Tidy3dBaseModel instances and their paths.
282411
@@ -288,16 +417,16 @@ def _core_model_traversal(
288417 current_obj : Any
289418 The current object in the traversal, which can be a :class:`Tidy3dBaseModel`,
290419 list, tuple, or other type.
291- current_path_segments : list [str]
292- A list of strings representing the path segments from the initial model
420+ current_path_segments : tuple [str, ... ]
421+ A tuple of strings representing the path segments from the initial model
293422 to the ``current_obj``.
294423
295424 Returns
296425 -------
297- Iterator[tuple[Self, list [str]]]
426+ Iterator[tuple[Self, tuple [str, ... ]]]
298427 An iterator yielding tuples, where the first element is a found :class:`Tidy3dBaseModel` instance
299- and the second is a list of strings representing the path to that instance
300- from the initial object. The path for the top-level model itself will be an empty list .
428+ and the second is a tuple of strings representing the path to that instance
429+ from the initial object. The path for the top-level model itself will be an empty tuple .
301430 """
302431 if isinstance (current_obj , Tidy3dBaseModel ):
303432 yield current_obj , current_path_segments
@@ -311,12 +440,12 @@ def _core_model_traversal(
311440
312441 field_value = getattr (current_obj , field_name )
313442 yield from Tidy3dBaseModel ._core_model_traversal (
314- field_value , [ * current_path_segments , field_name ]
443+ field_value , ( * current_path_segments , field_name )
315444 )
316445 elif isinstance (current_obj , (list , tuple )):
317446 for index , item in enumerate (current_obj ):
318447 yield from Tidy3dBaseModel ._core_model_traversal (
319- item , [ * current_path_segments , str (index )]
448+ item , ( * current_path_segments , str (index ))
320449 )
321450
322451 def find_paths (self , target_field_name : str , target_field_value : Any = Undefined ) -> list [str ]:
@@ -357,7 +486,7 @@ def find_paths(self, target_field_name: str, target_field_value: Any = Undefined
357486 found_paths_set = set ()
358487
359488 for sub_model_instance , path_segments_to_sub_model in Tidy3dBaseModel ._core_model_traversal (
360- self , []
489+ self , ()
361490 ):
362491 if target_field_name in sub_model_instance .model_fields :
363492 passes_value_filter = True
@@ -415,7 +544,7 @@ def find_submodels(self, target_type: Self) -> list[Self]:
415544 """
416545 found_models_dict = {}
417546
418- for sub_model_candidate , _ in Tidy3dBaseModel ._core_model_traversal (self , [] ):
547+ for sub_model_candidate , _ in Tidy3dBaseModel ._core_model_traversal (self , () ):
419548 if isinstance (sub_model_candidate , target_type ):
420549 if sub_model_candidate not in found_models_dict :
421550 found_models_dict [sub_model_candidate ] = True
0 commit comments