diff --git a/torax/plotting/configs/default_plot_config.py b/torax/plotting/configs/default_plot_config.py index a110d0fd..d65fb919 100644 --- a/torax/plotting/configs/default_plot_config.py +++ b/torax/plotting/configs/default_plot_config.py @@ -145,8 +145,8 @@ suppress_zero_values=True, # Do not plot all-zero data ), plotruns_lib.PlotProperties( - attrs=('q_brems',), - labels=(r'$Q_\mathrm{brems}$',), + attrs=('q_brems', 'q_imp'), + labels=(r'$Q_\mathrm{brems}$', r'$Q_\mathrm{imp}$'), ylabel=r'Heat sink density $[MW~m^{-3}]$', suppress_zero_values=True, # Do not plot all-zero data ), diff --git a/torax/plotting/configs/sources_plot_config.py b/torax/plotting/configs/sources_plot_config.py index 78227c0d..02aa0651 100644 --- a/torax/plotting/configs/sources_plot_config.py +++ b/torax/plotting/configs/sources_plot_config.py @@ -99,8 +99,8 @@ suppress_zero_values=True, # Do not plot all-zero data ), plotruns_lib.PlotProperties( - attrs=('q_brems',), - labels=(r'$Q_\mathrm{brems}$',), + attrs=('q_brems', 'q_imp'), + labels=(r'$Q_\mathrm{brems}$', r'$Q_\mathrm{imp}$'), ylabel=r'Heat sink density $[MW~m^{-3}]$', suppress_zero_values=True, # Do not plot all-zero data ), diff --git a/torax/plotting/plotruns_lib.py b/torax/plotting/plotruns_lib.py index 38448780..bf53ed60 100644 --- a/torax/plotting/plotruns_lib.py +++ b/torax/plotting/plotruns_lib.py @@ -111,6 +111,7 @@ class PlotData: q_ohmic: np.ndarray # [MW/m^3] q_brems: np.ndarray # [MW/m^3] q_ei: np.ndarray # [MW/m^3] + q_imp: np.ndarray # [MW/m^3] Q_fusion: np.ndarray # pylint: disable=invalid-name # Dimensionless s_puff: np.ndarray # [10^20 m^-3 s^-1] s_generic: np.ndarray # [10^20 m^-3 s^-1] @@ -171,12 +172,14 @@ def _transform_data(ds: xr.Dataset): 'fusion_heat_source_el': 1e6, # W/m^3 to MW/m^3 'ohmic_heat_source': 1e6, # W/m^3 to MW/m^3 'bremsstrahlung_heat_sink': 1e6, # W/m^3 to MW/m^3 + 'impurity_radiation_heat_sink': 1e6, # W/m^3 to MW/m^3 'qei_source': 1e6, # W/m^3 to MW/m^3 'P_ohmic': 1e6, # W to MW 'P_external_tot': 1e6, # W to MW 'P_alpha_tot': 1e6, # W to MW 'P_brems': 1e6, # W to MW 'P_ecrh': 1e6, # W to MW + 'P_imp': 1e6, # W to MW 'I_ecrh': 1e6, # A to MA 'I_generic': 1e6, # A to MA } @@ -246,6 +249,9 @@ def _transform_data(ds: xr.Dataset): q_brems=get_optional_data( core_sources_dataset, 'bremsstrahlung_heat_sink', 'cell' ), + q_imp=get_optional_data( + core_sources_dataset, 'impurity_radiation_heat_sink', 'cell' + ), q_ei=core_sources_dataset['qei_source'].to_numpy(), # ion heating/sink Q_fusion=post_processed_outputs_dataset['Q_fusion'].to_numpy(), # pylint: disable=invalid-name s_puff=get_optional_data(core_sources_dataset, 'gas_puff_source', 'cell'), @@ -263,7 +269,7 @@ def _transform_data(ds: xr.Dataset): - post_processed_outputs_dataset['P_ohmic'] ).to_numpy(), p_alpha=post_processed_outputs_dataset['P_alpha_tot'].to_numpy(), - p_sink=post_processed_outputs_dataset['P_brems'].to_numpy(), + p_sink=post_processed_outputs_dataset['P_brems'].to_numpy() + post_processed_outputs_dataset['P_imp'].to_numpy(), t=time, ) diff --git a/torax/post_processing.py b/torax/post_processing.py index 681e2fc8..30964c5e 100644 --- a/torax/post_processing.py +++ b/torax/post_processing.py @@ -40,6 +40,7 @@ 'ohmic_heat_source': 'P_ohmic', 'bremsstrahlung_heat_sink': 'P_brems', 'electron_cyclotron_source': 'P_ecrh', + 'impurity_radiation_heat_sink': 'P_imp', } EXTERNAL_HEATING_SOURCES = [ 'generic_ion_el_heat_source', diff --git a/torax/sources/impurity_radiation_heat_sink.py b/torax/sources/impurity_radiation_heat_sink.py new file mode 100644 index 00000000..bfcc876e --- /dev/null +++ b/torax/sources/impurity_radiation_heat_sink.py @@ -0,0 +1,128 @@ +"""Basic impurity radiation heat sink for electron heat equation..""" + +import dataclasses + +import chex +import jax +import jax.numpy as jnp + +from torax import array_typing +from torax import geometry +from torax import math_utils +from torax import state +from torax.config import runtime_params_slice +from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib +from torax.sources import source_models as source_models_lib + +SOURCE_NAME = "impurity_radiation_heat_sink" + + +def _radially_constant_fraction_of_Pin( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, +) -> jax.Array: + """Model function for radiation heat sink from impurities. + + This model represents a sink in the temp_el equation, whose value is a fixed + % of the total heating power input.""" + del ( + static_source_runtime_params, + ) # Unused + + # Based on source_models.sum_sources_temp_el and source_models.calc_and_sum_sources_psi, + # but only summing over heating *input* sources (Pohm + Paux + Palpha + ...) + # and summing over *both* ion and electron heating + + def get_heat_source_profile(source_name: str, source: source_lib.Source) -> jax.Array: + # TODO: Currently this recomputes the profile for each source, which is inefficient + # (and will be a problem if sources are slow/non-jittable) + # A similar TODO is noted in source_models.calc_and_sum_sources_psi + profile = source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[source_name], + static_runtime_params_slice=static_runtime_params_slice, + static_source_runtime_params=static_runtime_params_slice.sources[source_name], + geo=geo, + core_profiles=core_profiles, + ) + return source.get_source_profile_for_affected_core_profile( + profile, source_lib.AffectedCoreProfile.TEMP_EL.value, geo + ) + source.get_source_profile_for_affected_core_profile( + profile, source_lib.AffectedCoreProfile.TEMP_ION.value, geo + ) + + # Calculate the total power input to the heat equations + heat_sources_and_sinks = source_models.temp_el_sources | source_models.temp_ion_sources + heat_sources = {k: v for k, v in heat_sources_and_sinks.items() if not "sink" in k} + source_profiles = jax.tree.map( + get_heat_source_profile, + list(heat_sources.keys()), + list(heat_sources.values()), + ) + Qtot_in = jnp.sum(jnp.stack(source_profiles), axis=0) + Ptot_in = math_utils.cell_integration(Qtot_in * geo.vpr, geo) + Vtot = geo.volume_face[-1] + + # Calculate the heat sink as a fraction of the total power input + return ( + -dynamic_source_runtime_params.fraction_of_total_power_density + * Ptot_in / Vtot + * jnp.ones_like(geo.rho) + ) + + +@dataclasses.dataclass(kw_only=True) +class RuntimeParams(runtime_params_lib.RuntimeParams): + fraction_of_total_power_density: runtime_params_lib.TimeInterpolatedInput = 0.1 + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED + + def make_provider( + self, + torax_mesh: geometry.Grid1D | None = None, + ) -> "RuntimeParamsProvider": + return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) + + +@chex.dataclass +class RuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider): + """Provides runtime parameters for a given time and geometry.""" + + runtime_params_config: RuntimeParams + + def build_dynamic_params( + self, + t: chex.Numeric, + ) -> "DynamicRuntimeParams": + return DynamicRuntimeParams(**self.get_dynamic_params_kwargs(t)) + + +@chex.dataclass(frozen=True) +class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + fraction_of_total_power_density: array_typing.ScalarFloat + + +@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) +class ImpurityRadiationHeatSink(source_lib.Source): + """Impurity radiation heat sink for electron heat equation.""" + + source_models: source_models_lib.SourceModels + model_func: source_lib.SourceProfileFunction = _radially_constant_fraction_of_Pin + + @property + def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: + """Returns the modes supported by this source.""" + return ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.MODEL_BASED, + runtime_params_lib.Mode.PRESCRIBED, + ) + + @property + def affected_core_profiles(self) -> tuple[source_lib.AffectedCoreProfile, ...]: + return (source_lib.AffectedCoreProfile.TEMP_EL,) diff --git a/torax/sources/register_source.py b/torax/sources/register_source.py index 11085e2b..2614d1db 100644 --- a/torax/sources/register_source.py +++ b/torax/sources/register_source.py @@ -45,6 +45,7 @@ class to build, the runtime associated with that source and (optionally) the from torax.sources import fusion_heat_source from torax.sources import generic_current_source from torax.sources import generic_ion_el_heat_source as ion_el_heat +from torax.sources import impurity_radiation_heat_sink from torax.sources import ion_cyclotron_source from torax.sources import ohmic_heat_source from torax.sources import qei_source @@ -146,6 +147,11 @@ def _register_new_source( source_class=ion_cyclotron_source.IonCyclotronSource, default_runtime_params_class=ion_cyclotron_source.RuntimeParams, ), + impurity_radiation_heat_sink.SOURCE_NAME: _register_new_source( + source_class=impurity_radiation_heat_sink.ImpurityRadiationHeatSink, + default_runtime_params_class=impurity_radiation_heat_sink.RuntimeParams, + links_back=True, + ) } @@ -155,4 +161,3 @@ def get_registered_source(source_name: str) -> RegisteredSource: return _REGISTERED_SOURCES[source_name] else: raise RuntimeError(f'Source:{source_name} has not been registered.') - diff --git a/torax/sources/tests/impurity_radiation_heat_sink.py b/torax/sources/tests/impurity_radiation_heat_sink.py new file mode 100644 index 00000000..86e24208 --- /dev/null +++ b/torax/sources/tests/impurity_radiation_heat_sink.py @@ -0,0 +1,226 @@ +"""Tests for impurity_radiation_heat_sink.""" + +import chex +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest + +from torax import core_profile_setters +from torax import geometry +from torax import math_utils +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice +from torax.sources import generic_ion_el_heat_source +from torax.sources import ( + impurity_radiation_heat_sink as impurity_radiation_heat_sink_lib, +) +from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib +from torax.sources import source_models as source_models_lib +from torax.sources.tests import test_lib + + +class ImpurityRadiationHeatSinkTest(test_lib.SourceTestCase): + """Tests for ImpurityRadiationHeatSink.""" + + @classmethod + def setUpClass(cls): + super().setUpClass( + source_class=impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink, + runtime_params_class=impurity_radiation_heat_sink_lib.RuntimeParams, + unsupported_modes=[ + runtime_params_lib.Mode.MODEL_BASED, + runtime_params_lib.Mode.PRESCRIBED, + ], + links_back=True, + ) + + def test_source_value(self): + """Tests that the source value is correct.""" + # Source builder for this class + impurity_radiation_sink_builder = self._source_class_builder() + impurity_radiation_sink_builder.runtime_params.mode = ( + runtime_params_lib.Mode.MODEL_BASED + ) + if not source_lib.is_source_builder(impurity_radiation_sink_builder): + raise TypeError(f"{type(self)} has a bad _source_class_builder") + + # Source builder for generic_ion_el_heat_source + # We don't test this class, as that should be done in its own test + heat_source_builder_builder = source_lib.make_source_builder( + source_type=generic_ion_el_heat_source.GenericIonElectronHeatSource, + runtime_params_type=generic_ion_el_heat_source.RuntimeParams, + ) + heat_source_builder = heat_source_builder_builder() + + # Runtime params + runtime_params = general_runtime_params.GeneralRuntimeParams() + + # Source models + source_models_builder = source_models_lib.SourceModelsBuilder( + { + impurity_radiation_heat_sink_lib.SOURCE_NAME: impurity_radiation_sink_builder, + generic_ion_el_heat_source.SOURCE_NAME: heat_source_builder, + }, + ) + source_models = source_models_builder() + + # Extract the source we're testing and check that it's been built correctly + impurity_radiation_sink = source_models.sources[ + impurity_radiation_heat_sink_lib.SOURCE_NAME + ] + self.assertIsInstance(impurity_radiation_sink, source_lib.Source) + + # Geometry, profiles, and dynamic runtime params + geo = geometry.build_circular_geometry() + dynamic_runtime_params_slice = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + sources=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + )( + t=runtime_params.numerics.t_initial, + ) + ) + static_slice = runtime_params_slice.build_static_runtime_params_slice( + runtime_params, + source_runtime_params=source_models_builder.runtime_params, + ) + core_profiles = core_profile_setters.initial_core_profiles( + static_runtime_params_slice=static_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, + source_models=source_models, + ) + impurity_radiation_sink_dynamic_runtime_params_slice = ( + dynamic_runtime_params_slice.sources[ + impurity_radiation_heat_sink_lib.SOURCE_NAME + ] + ) + impurity_radiation_sink_static_runtime_params_slice = static_slice.sources[ + impurity_radiation_heat_sink_lib.SOURCE_NAME + ] + heat_source_dynamic_runtime_params_slice = dynamic_runtime_params_slice.sources[ + generic_ion_el_heat_source.SOURCE_NAME + ] + impurity_radiation_heat_sink_power_density = impurity_radiation_sink.get_value( + static_runtime_params_slice=static_slice, + static_source_runtime_params=impurity_radiation_sink_static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=impurity_radiation_sink_dynamic_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) + + # ImpurityRadiationHeatSink provides TEMP_EL only + chex.assert_rank(impurity_radiation_heat_sink_power_density, 1) + + # The value should be equal to fraction * sum of the (TEMP_EL+TEMP_ION) + # sources, minus P_ei and P_brems. + # In this case, that is only the generic_ion_el_heat_source. + impurity_radiation_heat_sink_power = math_utils.cell_integration( + impurity_radiation_heat_sink_power_density * geo.vpr, geo + ) + chex.assert_trees_all_close( + impurity_radiation_heat_sink_power, + heat_source_dynamic_runtime_params_slice.Ptot + * -impurity_radiation_sink_dynamic_runtime_params_slice.fraction_of_total_power_density, + rtol=1e-2, # TODO: this rtol seems v. high + ) + + def test_invalid_source_types_raise_errors(self): + """Tests that using unsupported types raises an error.""" + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry() + source_builder = self._source_class_builder() + source_models_builder = source_models_lib.SourceModelsBuilder( + {"foo": source_builder}, + ) + source_models = source_models_builder() + source = source_models.sources["foo"] + self.assertIsInstance(source, source_lib.Source) + dynamic_runtime_params_slice_provider = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + sources=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + ) + ) + # This slice is needed to create the core_profiles + dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( + t=runtime_params.numerics.t_initial, + ) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params, + source_runtime_params=source_models_builder.runtime_params, + ) + ) + core_profiles = core_profile_setters.initial_core_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo, + source_models=source_models, + ) + + for unsupported_mode in self._unsupported_modes: + source_builder.runtime_params.mode = unsupported_mode + # Construct a new slice with the given mode + dynamic_runtime_params_slice = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + sources=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + )( + t=runtime_params.numerics.t_initial, + ) + ) + with self.subTest(unsupported_mode.name): + with self.assertRaises(RuntimeError): + source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + "foo" + ], + static_runtime_params_slice=static_runtime_params_slice, + static_source_runtime_params=static_runtime_params_slice.sources[ + "foo" + ], + geo=geo, + core_profiles=core_profiles, + ) + + def test_extraction_of_relevant_profile_from_output(self): + """Tests that the relevant profile is extracted from the output.""" + geo = geometry.build_circular_geometry() + source_builder = self._source_class_builder() + source_models_builder = source_models_lib.SourceModelsBuilder( + {"foo": source_builder}, + ) + source_models = source_models_builder() + source = source_models.sources["foo"] + self.assertIsInstance(source, source_lib.Source) + cell = source_lib.ProfileType.CELL.get_profile_shape(geo) + fake_profile = jnp.ones(cell) + # Check TEMP_EL is modified + np.testing.assert_allclose( + source.get_source_profile_for_affected_core_profile( + fake_profile, + source_lib.AffectedCoreProfile.TEMP_EL.value, + geo, + ), + jnp.ones(cell), + ) + # For unrelated states, this should just return all 0s. + np.testing.assert_allclose( + source.get_source_profile_for_affected_core_profile( + fake_profile, + source_lib.AffectedCoreProfile.NE.value, + geo, + ), + jnp.zeros(cell), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/torax/state.py b/torax/state.py index e518e57b..d4ca8d3f 100644 --- a/torax/state.py +++ b/torax/state.py @@ -315,6 +315,7 @@ class PostProcessedOutputs: P_ohmic: Ohmic heating power to electrons [W] P_brems: Bremsstrahlung electron heat sink [W] P_ecrh: Total electron cyclotron source power [W] + P_imp: Impurity radiation heat sink [W] I_ecrh: Total electron cyclotron source current [A] I_generic: Total generic source current [A] Q_fusion: Fusion power gain @@ -364,6 +365,7 @@ class PostProcessedOutputs: P_ohmic: array_typing.ScalarFloat P_brems: array_typing.ScalarFloat P_ecrh: array_typing.ScalarFloat + P_imp: array_typing.ScalarFloat I_ecrh: array_typing.ScalarFloat I_generic: array_typing.ScalarFloat Q_fusion: array_typing.ScalarFloat @@ -412,6 +414,7 @@ def zeros(cls, geo: geometry.Geometry) -> PostProcessedOutputs: P_ohmic=jnp.array(0.0), P_brems=jnp.array(0.0), P_ecrh=jnp.array(0.0), + P_imp=jnp.array(0.0), I_ecrh=jnp.array(0.0), I_generic=jnp.array(0.0), Q_fusion=jnp.array(0.0), diff --git a/torax/tests/sim.py b/torax/tests/sim.py index 4f8acf3f..03569456 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -381,6 +381,13 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Predictor-corrector solver with simple impurity radiation + ( + 'test_iterhybrid_predictor_corrector_impurity_radiation', + 'test_iterhybrid_predictor_corrector_impurity_radiation.py', + _ALL_PROFILES, + 0, + ), # Tests Newton-Raphson nonlinear solver for ITER-hybrid-like-config ( 'test_iterhybrid_newton', diff --git a/torax/tests/test_data/test_changing_config_after.nc b/torax/tests/test_data/test_changing_config_after.nc index ccb0385c..a329be7f 100644 Binary files a/torax/tests/test_data/test_changing_config_after.nc and b/torax/tests/test_data/test_changing_config_after.nc differ diff --git a/torax/tests/test_data/test_changing_config_before.nc b/torax/tests/test_data/test_changing_config_before.nc index 9e039679..74eb5459 100644 Binary files a/torax/tests/test_data/test_changing_config_before.nc and b/torax/tests/test_data/test_changing_config_before.nc differ diff --git a/torax/tests/test_data/test_implicit.nc b/torax/tests/test_data/test_implicit.nc index 234ecf38..aad4e4f6 100644 Binary files a/torax/tests/test_data/test_implicit.nc and b/torax/tests/test_data/test_implicit.nc differ diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector_impurity_radiation.nc b/torax/tests/test_data/test_iterhybrid_predictor_corrector_impurity_radiation.nc new file mode 100644 index 00000000..46b7706b Binary files /dev/null and b/torax/tests/test_data/test_iterhybrid_predictor_corrector_impurity_radiation.nc differ diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector_impurity_radiation.py b/torax/tests/test_data/test_iterhybrid_predictor_corrector_impurity_radiation.py new file mode 100644 index 00000000..38b3483a --- /dev/null +++ b/torax/tests/test_data/test_iterhybrid_predictor_corrector_impurity_radiation.py @@ -0,0 +1,23 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Identical to test_iterhybrid_predictor_corrector but simple impurity radiation.""" + +import copy + +from torax.tests.test_data import test_iterhybrid_predictor_corrector + +CONFIG = copy.deepcopy(test_iterhybrid_predictor_corrector.CONFIG) + +CONFIG['sources']['impurity_radiation_heat_sink'] = {}