Skip to content

Commit a29b431

Browse files
hamelphiTorax team
authored andcommitted
Extract base class from IonCyclotronSourceConfig for ICRH.
Rename IonCyclotronSourceConfig to ToricNNIonCyclotronSourceConfig and introduce a base IonCyclotronSourceConfig class that holds shared ICRH configuration (P_total, absorption_fraction, mode). This prepares for future ICRH model variants beyond the ToricNN surrogate. Key changes: - IonCyclotronSourceConfig is now the base class with shared fields. - ToricNNIonCyclotronSourceConfig inherits from it with ToricNN-specific fields (model_path, wall geometry, frequency, minority params). - Update pydantic_model.py type annotation to ToricNNIonCyclotronSourceConfig. - Update model_config.py validator to use isinstance check. - Update tests to reference ToricNNIonCyclotronSourceConfig. PiperOrigin-RevId: 892473531
1 parent 412b9be commit a29b431

4 files changed

Lines changed: 47 additions & 26 deletions

File tree

torax/_src/sources/ion_cyclotron_source.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Surrogate model for ion-cyclotron resonance heating (ICRH) model."""
14+
"""Ion-cyclotron resonance heating (ICRH) source models."""
1515

1616
import dataclasses
1717
import functools
@@ -580,7 +580,7 @@ def _build_fast_ions(
580580

581581
@dataclasses.dataclass(kw_only=True, frozen=True, eq=False)
582582
class IonCyclotronSource(source.Source):
583-
"""Ion cyclotron source with surrogate model."""
583+
"""Ion cyclotron source."""
584584

585585
SOURCE_NAME: ClassVar[str] = 'icrh'
586586
AFFECTED_CORE_PROFILES: ClassVar[tuple[source.AffectedCoreProfile, ...]] = (
@@ -615,7 +615,37 @@ def _icrh_model_func_with_toric_nn(
615615

616616

617617
class IonCyclotronSourceConfig(base.SourceModelBase):
618-
"""Configuration for the IonCyclotronSource.
618+
"""Base configuration for IonCyclotronSource.
619+
620+
This base class contains fields common to all ICRH model implementations.
621+
Subclasses implement the specific model logic (e.g., ToricNN, scaled curves)
622+
and must override `model_name` with a `Literal` to serve as discriminator.
623+
624+
Attributes:
625+
model_name: Discriminator field for Pydantic. Subclasses must override with
626+
a `Literal` value.
627+
P_total: Total heating power [W].
628+
absorption_fraction: Fraction of absorbed power.
629+
mode: Defines how the source values are computed.
630+
"""
631+
632+
model_name: Annotated[str, torax_pydantic.JAX_STATIC] = ''
633+
P_total: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(
634+
10e6
635+
)
636+
absorption_fraction: torax_pydantic.PositiveTimeVaryingScalar = (
637+
torax_pydantic.ValidatedDefault(1.0)
638+
)
639+
mode: Annotated[source_runtime_params_lib.Mode, torax_pydantic.JAX_STATIC] = (
640+
source_runtime_params_lib.Mode.MODEL_BASED
641+
)
642+
643+
def build_source(self) -> IonCyclotronSource:
644+
return IonCyclotronSource(model_func=self.model_func)
645+
646+
647+
class ToricNNIonCyclotronSourceConfig(IonCyclotronSourceConfig):
648+
"""Configuration for the IonCyclotronSource using the ToricNN model.
619649
620650
Attributes:
621651
model_path: Path to JSON weights and model config of ToricNN model.
@@ -633,8 +663,6 @@ class IonCyclotronSourceConfig(base.SourceModelBase):
633663
plasma_composition instead of using minority_concentration parameter. The
634664
species can be either a main ion (if hydrogenic) or an impurity (if
635665
helium).
636-
P_total: Total heating power [W].
637-
absorption_fraction: Fraction of absorbed power.
638666
"""
639667

640668
model_name: Annotated[Literal['toric_nn'], torax_pydantic.JAX_STATIC] = (
@@ -651,15 +679,6 @@ class IonCyclotronSourceConfig(base.SourceModelBase):
651679
torax_pydantic.ValidatedDefault(0.03)
652680
)
653681
minority_species: Annotated[str | None, torax_pydantic.JAX_STATIC] = None
654-
P_total: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(
655-
10e6
656-
)
657-
absorption_fraction: torax_pydantic.PositiveTimeVaryingScalar = (
658-
torax_pydantic.ValidatedDefault(1.0)
659-
)
660-
mode: Annotated[source_runtime_params_lib.Mode, torax_pydantic.JAX_STATIC] = (
661-
source_runtime_params_lib.Mode.MODEL_BASED
662-
)
663682

664683
@property
665684
def model_func(self) -> source.SourceProfileFunction:
@@ -713,6 +732,3 @@ def build_runtime_params(
713732
P_total=self.P_total.get_value(t),
714733
absorption_fraction=self.absorption_fraction.get_value(t),
715734
)
716-
717-
def build_source(self) -> IonCyclotronSource:
718-
return IonCyclotronSource(model_func=self.model_func)

torax/_src/sources/pydantic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class Sources(torax_pydantic.BaseModelFrozen):
9191
discriminator='model_name',
9292
default=None,
9393
)
94-
icrh: ion_cyclotron_source_lib.IonCyclotronSourceConfig | None = (
94+
icrh: ion_cyclotron_source_lib.ToricNNIonCyclotronSourceConfig | None = (
9595
pydantic.Field(
9696
discriminator='model_name',
9797
default=None,

torax/_src/sources/tests/ion_cyclotron_source_test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,14 @@ def setUp(self):
100100
self.dummy_input = model_input
101101
self.dummy_output = model_output
102102
super().setUp(
103-
source_config_class=ion_cyclotron_source.IonCyclotronSourceConfig,
103+
source_config_class=ion_cyclotron_source.ToricNNIonCyclotronSourceConfig,
104104
source_name=ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME,
105105
)
106106
# pytype: enable=signature-mismatch
107107

108108
def config_raises_if_model_path_does_not_exist(self):
109109
with self.assertRaises(FileNotFoundError):
110-
ion_cyclotron_source.IonCyclotronSourceConfig.from_dict(
110+
ion_cyclotron_source.ToricNNIonCyclotronSourceConfig.from_dict(
111111
{"model_path": "/tmp/non_existent_file.json"}
112112
)
113113

@@ -338,21 +338,25 @@ def test_source_raises_if_minority_species_not_in_composition(self):
338338

339339
def test_minority_concentration_warning_by_default(self):
340340
with self.assertLogs(level="WARNING") as cm:
341-
ion_cyclotron_source.IonCyclotronSourceConfig()
341+
ion_cyclotron_source.ToricNNIonCyclotronSourceConfig()
342342
self.assertTrue(
343343
any("minority_concentration is provided" in o for o in cm.output)
344344
)
345345

346346
def test_minority_concentration_warning_when_explicitly_set_to_non_none(self):
347347
with self.assertLogs(level="WARNING") as cm:
348-
ion_cyclotron_source.IonCyclotronSourceConfig(minority_concentration=0.05)
348+
ion_cyclotron_source.ToricNNIonCyclotronSourceConfig(
349+
minority_concentration=0.05
350+
)
349351
self.assertTrue(
350352
any("minority_concentration is provided" in o for o in cm.output)
351353
)
352354

353355
def test_minority_concentration_no_warning_when_none(self):
354356
with self.assertNoLogs(level="WARNING"):
355-
ion_cyclotron_source.IonCyclotronSourceConfig(minority_concentration=None)
357+
ion_cyclotron_source.ToricNNIonCyclotronSourceConfig(
358+
minority_concentration=None
359+
)
356360

357361
def test_icrh_returns_fast_ion_data(self):
358362
config = default_configs.get_default_config_dict()

torax/_src/torax_pydantic/model_config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,10 @@ def _validate_toric_nn_he3_presence(self) -> typing_extensions.Self:
350350
is not set.
351351
"""
352352
if (
353-
self.sources.icrh is not None
354-
and self.sources.icrh.model_name
355-
== ion_cyclotron_source_lib.DEFAULT_MODEL_FUNCTION_NAME
353+
isinstance(
354+
self.sources.icrh,
355+
ion_cyclotron_source_lib.ToricNNIonCyclotronSourceConfig,
356+
)
356357
and self.sources.icrh.minority_species is not None
357358
):
358359
he3_present = (

0 commit comments

Comments
 (0)