-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add simple impurity radiation model (P_imp as a fraction of P_tot_in) #572
Open
theo-brown
wants to merge
8
commits into
main
Choose a base branch
from
theo-brown/add-prad
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+425
−18
Open
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
a005f07
Add radiation_heat_sink, proportional to Qtot
theo-brown ff0c854
Add Prad to post_processing and plots
theo-brown cc16cce
Support sources with links_back=True
theo-brown 7a0936a
Add unit test for RadiationHeatSink
theo-brown 601cde4
Merge branch 'main' into theo-brown/add-prad
theo-brown 65e5836
Sum over both ion and el heat sources
theo-brown c6710b6
Switch to radially constant fraction of Pin
theo-brown 0d242e4
Minor changes
theo-brown File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
CONFIG = { | ||
'runtime_params': {}, | ||
'geometry': { | ||
'geometry_type': 'circular', | ||
}, | ||
'sources': { | ||
'j_bootstrap': {}, | ||
'generic_current_source': {}, | ||
'generic_particle_source': {}, | ||
'gas_puff_source': {}, | ||
'pellet_source': {}, | ||
'generic_ion_el_heat_source': {}, | ||
'fusion_heat_source': {}, | ||
'qei_source': {}, | ||
'ohmic_heat_source': {}, | ||
'radiation_heat_sink': {}, | ||
}, | ||
'transport': { | ||
'transport_model': 'constant', | ||
}, | ||
'stepper': { | ||
'stepper_type': 'linear', | ||
}, | ||
'time_step_calculator': { | ||
'calculator_type': 'chi', | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Basic radiation heat sink for electron heat equation..""" | ||
|
||
import dataclasses | ||
|
||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from torax import array_typing | ||
from torax import geometry | ||
from torax import state | ||
from torax.config import runtime_params_slice | ||
from torax.sources import bremsstrahlung_heat_sink | ||
from torax.sources import qei_source | ||
from torax.sources import runtime_params as runtime_params_lib | ||
from torax.sources import source as source_lib | ||
from torax.sources import source_models as source_models_lib | ||
|
||
SOURCE_NAME = "radiation_heat_sink" | ||
|
||
|
||
def _Qrad_as_fraction_of_Qtot_in( | ||
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, | ||
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, | ||
geo: geometry.Geometry, | ||
core_profiles: state.CoreProfiles, | ||
source_models: source_models_lib.SourceModels, | ||
) -> jax.Array: | ||
"""Model function for radiation heat sink. | ||
|
||
In this model, a fixed % of the total power input to the temp_el equation is lost.""" | ||
theo-brown marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Based on source_models.sum_sources_temp_el and source_models.calc_and_sum_sources_psi, | ||
# but only summing over heating *input* sources (Pohm + Paux + Palpha) | ||
# and summing over *both* ion and electron heating | ||
|
||
def get_temp_el_profile(source_name: str, source: source_lib.Source) -> jax.Array: | ||
# TODO: Currently this recomputes the profile for each source, which is inefficient | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree. We will prioritize changing the pattern such that this won't happen. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK for the time being though? |
||
# (and will be a problem if sources are slow/non-jittable) | ||
# A similar TODO is noted in source_models.calc_and_sum_sources_psi | ||
profile = source.get_value( | ||
dynamic_runtime_params_slice, | ||
dynamic_runtime_params_slice.sources[source_name], | ||
geo, | ||
core_profiles, | ||
) | ||
return source.get_source_profile_for_affected_core_profile( | ||
profile, source_lib.AffectedCoreProfile.TEMP_EL.value, geo | ||
) + source.get_source_profile_for_affected_core_profile( | ||
profile, source_lib.AffectedCoreProfile.TEMP_ION.value, geo | ||
) | ||
|
||
# Manually remove sources that will not be summed | ||
sources_to_sum = source_models.temp_el_sources | source_models.temp_ion_sources | ||
sources_to_sum.pop(SOURCE_NAME, None) | ||
sources_to_sum.pop(bremsstrahlung_heat_sink.SOURCE_NAME, None) | ||
theo-brown marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sources_to_sum.pop(qei_source.SOURCE_NAME, None) | ||
|
||
source_profiles = jax.tree.map( | ||
get_temp_el_profile, | ||
list(sources_to_sum.keys()), | ||
list(sources_to_sum.values()) | ||
) | ||
|
||
Qtot_in = jnp.sum(jnp.stack(source_profiles), axis=0) | ||
|
||
# Calculate the radiation heat sink | ||
return ( | ||
-dynamic_source_runtime_params.fraction_of_total_power_density | ||
* Qtot_in | ||
* jnp.ones_like(geo.rho) | ||
) | ||
|
||
|
||
@dataclasses.dataclass(kw_only=True) | ||
class RuntimeParams(runtime_params_lib.RuntimeParams): | ||
fraction_of_total_power_density: runtime_params_lib.TimeInterpolatedInput = 0.1 | ||
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED | ||
|
||
def make_provider( | ||
self, | ||
torax_mesh: geometry.Grid1D | None = None, | ||
) -> "RuntimeParamsProvider": | ||
return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) | ||
|
||
|
||
@chex.dataclass | ||
class RuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider): | ||
"""Provides runtime parameters for a given time and geometry.""" | ||
|
||
runtime_params_config: RuntimeParams | ||
|
||
def build_dynamic_params( | ||
self, | ||
t: chex.Numeric, | ||
) -> "DynamicRuntimeParams": | ||
return DynamicRuntimeParams(**self.get_dynamic_params_kwargs(t)) | ||
|
||
|
||
@chex.dataclass(frozen=True) | ||
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): | ||
fraction_of_total_power_density: array_typing.ScalarFloat | ||
|
||
|
||
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) | ||
class RadiationHeatSink(source_lib.Source): | ||
"""Radiation heat sink for electron heat equation.""" | ||
|
||
source_models: source_models_lib.SourceModels | ||
model_func: source_lib.SourceProfileFunction = _Qrad_as_fraction_of_Qtot_in | ||
|
||
@property | ||
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: | ||
"""Returns the modes supported by this source.""" | ||
return ( | ||
runtime_params_lib.Mode.ZERO, | ||
runtime_params_lib.Mode.MODEL_BASED, | ||
runtime_params_lib.Mode.PRESCRIBED, | ||
) | ||
|
||
@property | ||
def affected_core_profiles(self) -> tuple[source_lib.AffectedCoreProfile, ...]: | ||
return (source_lib.AffectedCoreProfile.TEMP_EL,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure there's much point in making an example config out of a single feature?
Our standard pattern is to make a new integration test in tests/tests_data with new features for regression and integration testing.
Suggest to remove it here, make a new integration test (to add to tests/sim.py also) based on one of the standard cases like test_iterhybrid_predictor_corrector
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'll definitely do this. I hadn't actually meant to leave this config in, I'd forgotten that I hadn't got round to turning it into an integration test yet!