Skip to content

Commit 4323798

Browse files
theo-brownTorax team
authored andcommitted
Create NoSaturationModel and NoFormationModel
Useful for testing and debugging PiperOrigin-RevId: 877364049
1 parent 8f36b6c commit 4323798

3 files changed

Lines changed: 132 additions & 2 deletions

File tree

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""No pedestal formation model."""
16+
17+
import dataclasses
18+
import jax
19+
from torax._src import state
20+
from torax._src.config import runtime_params as runtime_params_lib
21+
from torax._src.geometry import geometry
22+
from torax._src.pedestal_model import pedestal_model_output
23+
from torax._src.pedestal_model.formation import base
24+
from torax._src.sources import source_profiles as source_profiles_lib
25+
26+
# pylint: disable=invalid-name
27+
28+
29+
@jax.tree_util.register_dataclass
30+
@dataclasses.dataclass(frozen=True, eq=False)
31+
class NoFormationModel(base.FormationModel):
32+
"""No pedestal formation model."""
33+
34+
def __call__(
35+
self,
36+
runtime_params: runtime_params_lib.RuntimeParams,
37+
geo: geometry.Geometry,
38+
core_profiles: state.CoreProfiles,
39+
core_sources: source_profiles_lib.SourceProfiles,
40+
) -> pedestal_model_output.TransportMultipliers:
41+
"""Returns no transport multipliers."""
42+
return pedestal_model_output.TransportMultipliers.default()

torax/_src/pedestal_model/pydantic_model.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,37 @@
2626
from torax._src.pedestal_model import set_pped_tpedratio_nped
2727
from torax._src.pedestal_model import set_tped_nped
2828
from torax._src.pedestal_model.formation import martin_formation_model
29+
from torax._src.pedestal_model.formation import no_formation_model
30+
from torax._src.pedestal_model.saturation import no_saturation_model
2931
from torax._src.pedestal_model.saturation import profile_value_saturation_model
3032
from torax._src.torax_pydantic import torax_pydantic
3133

3234
# pylint: disable=invalid-name
3335

3436

37+
class NoFormation(torax_pydantic.BaseModelFrozen):
38+
"""Configuration for No formation model."""
39+
40+
model_name: Annotated[Literal["no_formation"], torax_pydantic.JAX_STATIC] = (
41+
"no_formation"
42+
)
43+
44+
def build_formation_model(
45+
self,
46+
) -> no_formation_model.NoFormationModel:
47+
return no_formation_model.NoFormationModel()
48+
49+
def build_runtime_params(
50+
self, t: chex.Numeric
51+
) -> runtime_params.FormationRuntimeParams:
52+
del t
53+
return runtime_params.FormationRuntimeParams(
54+
sigmoid_width=0.0,
55+
sigmoid_offset=0.0,
56+
sigmoid_exponent=1.0,
57+
)
58+
59+
3560
# TODO(b/323504363): Generalise to pedestal formation models based on power
3661
# thresholds (e.g. Metal Wall scaling), not just Martin scaling.
3762
class MartinFormation(torax_pydantic.BaseModelFrozen):
@@ -87,6 +112,29 @@ def build_runtime_params(
87112
)
88113

89114

115+
class NoSaturation(torax_pydantic.BaseModelFrozen):
116+
"""Configuration for No saturation model."""
117+
118+
model_name: Annotated[Literal["no_saturation"], torax_pydantic.JAX_STATIC] = (
119+
"no_saturation"
120+
)
121+
122+
def build_saturation_model(
123+
self,
124+
) -> no_saturation_model.NoSaturationModel:
125+
return no_saturation_model.NoSaturationModel()
126+
127+
def build_runtime_params(
128+
self, t: chex.Numeric
129+
) -> runtime_params.SaturationRuntimeParams:
130+
del t
131+
return runtime_params.SaturationRuntimeParams(
132+
sigmoid_width=0.0,
133+
sigmoid_offset=0.0,
134+
sigmoid_exponent=1.0,
135+
)
136+
137+
90138
class ProfileValueSaturation(torax_pydantic.BaseModelFrozen):
91139
"""Configuration for ProfileValueSaturation model.
92140
@@ -140,8 +188,8 @@ def build_runtime_params(
140188

141189

142190
# For new formation and saturation models, add to these TypeAliases via Union.
143-
FormationConfig: TypeAlias = MartinFormation
144-
SaturationConfig: TypeAlias = ProfileValueSaturation
191+
FormationConfig: TypeAlias = MartinFormation | NoFormation
192+
SaturationConfig: TypeAlias = ProfileValueSaturation | NoSaturation
145193

146194

147195
class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC):
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""No pedestal saturation model."""
16+
17+
import dataclasses
18+
from torax._src import array_typing
19+
from torax._src import state
20+
from torax._src.config import runtime_params as runtime_params_lib
21+
from torax._src.geometry import geometry
22+
from torax._src.pedestal_model import pedestal_model_output
23+
from torax._src.pedestal_model.saturation import base
24+
25+
# pylint: disable=invalid-name
26+
27+
28+
@dataclasses.dataclass(frozen=True, eq=False)
29+
class NoSaturationModel(base.SaturationModel):
30+
"""No pedestal saturation model."""
31+
32+
def __call__(
33+
self,
34+
runtime_params: runtime_params_lib.RuntimeParams,
35+
geo: geometry.Geometry,
36+
core_profiles: state.CoreProfiles,
37+
pedestal_output: pedestal_model_output.PedestalModelOutput,
38+
) -> array_typing.FloatScalar:
39+
"""Returns no transport multipliers."""
40+
return pedestal_model_output.TransportMultipliers.default()

0 commit comments

Comments
 (0)