diff --git a/openmc/source.py b/openmc/source.py index a77ca8b04c0..5deb6b05f50 100644 --- a/openmc/source.py +++ b/openmc/source.py @@ -399,20 +399,19 @@ class MeshSource(SourceBase): ---------- mesh : openmc.MeshBase The mesh over which source sites will be generated. - sources : iterable of openmc.SourceBase - Sources for each element in the mesh. If spatial distributions are set - on any of the source objects, they will be ignored during source site - sampling. + sources : sequence of openmc.SourceBase + Sources for each element in the mesh. Sources must be specified as + either a 1-D array in the order of the mesh indices or a + multidimensional array whose shape matches the mesh shape. If spatial + distributions are set on any of the source objects, they will be ignored + during source site sampling. Attributes ---------- mesh : openmc.MeshBase The mesh over which source sites will be generated. - sources : numpy.ndarray or iterable of openmc.SourceBase - The set of sources to apply to each element. The shape of this array - must match the shape of the mesh with and exception in the case of - unstructured mesh, which allows for application of 1-D array or - iterable. + sources : numpy.ndarray of openmc.SourceBase + Sources to apply to each element strength : float Strength of the source type : str @@ -433,7 +432,7 @@ def mesh(self) -> MeshBase: @property def strength(self) -> float: - return sum(s.strength for s in self.sources.flat) + return sum(s.strength for s in self.sources) @property def sources(self) -> np.ndarray: @@ -450,16 +449,23 @@ def sources(self, s): s = np.asarray(s) - if isinstance(self.mesh, StructuredMesh) and s.shape != self.mesh.dimension: - raise ValueError('The shape of the source array' - f'({s.shape}) does not match the ' - f'dimensions of the structured mesh ({self.mesh.dimension})') + if isinstance(self.mesh, StructuredMesh): + if s.size != self.mesh.num_mesh_cells: + raise ValueError( + f'The length of the source array ({s.size}) does not match ' + f'the number of mesh elements ({self.mesh.num_mesh_cells}).') + + # If user gave a multidimensional array, flatten in the order + # of the mesh indices + if s.ndim > 1: + s = s.ravel(order='F') + elif isinstance(self.mesh, UnstructuredMesh): - if len(s.shape) > 1: + if s.ndim > 1: raise ValueError('Sources must be a 1-D array for unstructured mesh') self._sources = s - for src in self._sources.flat: + for src in self._sources: if isinstance(src, IndependentSource) and src.space is not None: warnings.warn('Some sources on the mesh have spatial ' 'distributions that will be ignored at runtime.') @@ -481,7 +487,7 @@ def set_total_strength(self, strength: float): """ current_strength = self.strength if self.strength != 0.0 else 1.0 - for s in self.sources.flat: + for s in self.sources: s.strength *= strength / current_strength def normalize_source_strengths(self): @@ -500,13 +506,8 @@ def populate_xml_element(self, elem: ET.Element): elem.set("mesh", str(self.mesh.id)) # write in the order of mesh indices - if isinstance(self.mesh, openmc.UnstructuredMesh): - for s in self.sources: - elem.append(s.to_xml_element()) - else: - for idx in self.mesh.indices: - idx = tuple(i - 1 for i in idx) - elem.append(self.sources[idx].to_xml_element()) + for s in self.sources: + elem.append(s.to_xml_element()) @classmethod def from_xml_element(cls, elem: ET.Element, meshes) -> openmc.MeshSource: @@ -527,11 +528,9 @@ def from_xml_element(cls, elem: ET.Element, meshes) -> openmc.MeshSource: MeshSource generated from the XML element """ mesh_id = int(get_text(elem, 'mesh')) - mesh = meshes[mesh_id] sources = [SourceBase.from_xml_element(e) for e in elem.iterchildren('source')] - sources = np.asarray(sources).reshape(mesh.dimension, order='F') return cls(mesh, sources) diff --git a/tests/unit_tests/test_source_mesh.py b/tests/unit_tests/test_source_mesh.py index f0813d2f9e0..43bb1678c40 100644 --- a/tests/unit_tests/test_source_mesh.py +++ b/tests/unit_tests/test_source_mesh.py @@ -276,12 +276,12 @@ def test_mesh_source_independent(run_in_tmpdir, void_model, mesh_type): # for each element, set a single-non zero source with particles # traveling out of the mesh (and geometry) w/o crossing any other # mesh elements - for i, j, k in mesh.indices: + for flat_index, (i, j, k) in enumerate(mesh.indices): ijk = (i-1, j-1, k-1) # zero-out all source strengths and set the strength # on the element of interest mesh_source.strength = 0.0 - mesh_source.sources[ijk].strength = 1.0 + mesh_source.sources[flat_index].strength = 1.0 sp_file = model.run() @@ -375,10 +375,7 @@ def test_mesh_source_file(run_in_tmpdir): mesh.upper_right = (2, 3, 4) mesh.dimension = (1, 1, 1) - mesh_source_arr = np.asarray([file_source]).reshape(mesh.dimension) - source = openmc.MeshSource(mesh, mesh_source_arr) - - model.settings.source = source + model.settings.source = openmc.MeshSource(mesh, [file_source]) model.export_to_model_xml()