44
55from abc import ABC , abstractmethod
66from math import isclose
7- from typing import TYPE_CHECKING , Literal , Optional , Union
7+ from typing import TYPE_CHECKING , Any , Literal , Optional , Union
88
99import numpy as np
1010from pydantic import (
@@ -112,7 +112,7 @@ class ModeSortSpec(Tidy3dBaseModel):
112112 title = "Filtering order" ,
113113 description = "Select whether the first group contains values over or under the reference." ,
114114 )
115- bounding_box : Optional [Box ] = pd . Field (
115+ bounding_box : Optional [Box ] = Field (
116116 None ,
117117 title = "Bounding box" ,
118118 description = (
@@ -121,7 +121,7 @@ class ModeSortSpec(Tidy3dBaseModel):
121121 "still intersect the monitor plane. Required when filtering or sorting with that key."
122122 ),
123123 )
124- keep_modes : Union [Literal ["all" ], Literal ["filtered" ], pd . PositiveInt ] = pd . Field (
124+ keep_modes : Union [Literal ["all" ], Literal ["filtered" ], PositiveInt ] = Field (
125125 "all" ,
126126 title = "Keep Modes" ,
127127 description = (
@@ -151,19 +151,22 @@ class ModeSortSpec(Tidy3dBaseModel):
151151 description = _build_sort_order_description (),
152152 )
153153
154- @pd . validator ( "sort_order" , always = True )
155- @skip_if_fields_missing ([ "sort_key" , "sort_reference" ])
156- def _set_default_sort_order (cls , val , values ) :
154+ @model_validator ( mode = "before" )
155+ @classmethod
156+ def _set_default_sort_order (cls , data : dict [ str , Any ]) -> dict [ str , Any ] :
157157 """Set default sort order based on sort_key and sort_reference."""
158+ val = data .get ("sort_order" )
158159 if val is not None :
159- return val
160- sort_reference = values .get ("sort_reference" )
160+ return data
161+ sort_reference = data .get ("sort_reference" )
161162 # When sorting by distance to a reference, ascending is natural (closest first)
162163 if sort_reference is not None :
163- return "ascending"
164+ data ["sort_order" ] = "ascending"
165+ return data
164166 # Otherwise, use the natural default for each key
165- sort_key = values .get ("sort_key" )
166- return MODE_DATA_KEY_SORT_ORDER .get (sort_key , "ascending" )
167+ sort_key = data .get ("sort_key" , "n_eff" )
168+ data ["sort_order" ] = MODE_DATA_KEY_SORT_ORDER .get (sort_key , "ascending" )
169+ return data
167170
168171 # Frequency tracking - applied after sorting and filtering
169172 track_freq : Optional [TrackFreq ] = Field (
@@ -175,23 +178,25 @@ def _set_default_sort_order(cls, val, values):
175178 "while at other frequencies it can change depending on the mode tracking." ,
176179 )
177180
178- @pd .validator ("keep_modes" , always = True )
179- def _drop_requires_filter (cls , val , values ):
180- if val == "filtered" and values .get ("filter_key" ) is None :
181+ @model_validator (mode = "after" )
182+ def _drop_requires_filter (self : Self ) -> Self :
183+ val = self .keep_modes
184+ if val == "filtered" and self .filter_key is None :
181185 raise ValidationError (
182186 "ModeSortSpec.keep_modes 'filtered' requires 'filter_key' to be set."
183187 )
184- return val
188+ return self
185189
186- @pd .root_validator (skip_on_failure = True )
187- def _bounding_box_required_for_fill_fraction (cls , values ):
188- bbox = values .get ("bounding_box" )
189- keys = (values .get ("filter_key" ), values .get ("sort_key" ))
190+ @model_validator (mode = "before" )
191+ @classmethod
192+ def _bounding_box_required_for_fill_fraction (cls , data : dict [str , Any ]) -> dict [str , Any ]:
193+ bbox = data .get ("bounding_box" )
194+ keys = (data .get ("filter_key" ), data .get ("sort_key" ))
190195 if any (key == "fill_fraction_box" for key in keys ) and bbox is None :
191196 raise ValidationError (
192197 "ModeSortSpec.bounding_box must be set when using 'fill_fraction_box'."
193198 )
194- return values
199+ return data
195200
196201 @property
197202 def has_custom_sort_or_filter (self ) -> bool :
@@ -210,7 +215,7 @@ def has_custom_sort_or_filter(self) -> bool:
210215 self .filter_key is not None
211216 or self .sort_key != "n_eff"
212217 or self .sort_reference is not None
213- or self .sort_order != "descending"
218+ or ( self .sort_order is not None and self . sort_order != "descending" )
214219 or self .keep_modes != "all"
215220 )
216221
@@ -692,11 +697,12 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC):
692697 "not be ``None``) to ensure consistent mode ordering across frequencies." ,
693698 )
694699
695- @pd .validator ("sort_spec" , always = True )
696- def _keep_modes_at_most_num_modes (cls , val , values ):
700+ @model_validator (mode = "after" )
701+ def _keep_modes_at_most_num_modes (self : Self ) -> Self :
702+ val = self .sort_spec
697703 if val is not None :
698704 if isinstance (val .keep_modes , int ):
699- num_modes = values . get ( " num_modes" )
705+ num_modes = self . num_modes
700706 if val .keep_modes > num_modes :
701707 raise ValidationError (
702708 "ModeSortSpec.keep_modes cannot be larger than 'num_modes'. "
@@ -708,7 +714,7 @@ def _keep_modes_at_most_num_modes(cls, val, values):
708714 "'keep_modes=\" filtered\" ' to keep exactly those modes matching "
709715 "the filter, or setting 'keep_modes=\" all\" ' to keep all modes."
710716 )
711- return val
717+ return self
712718
713719 @field_validator ("group_index_step" , mode = "before" )
714720 @classmethod
0 commit comments