diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index abfb9bd71..6b8d5d058 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -629,9 +629,14 @@ def render_model( if filename is not None: filename = Path(filename) + # remove leading period from suffix + filename_without_suffix = filename.with_suffix("") graph.render( - filename.stem, view=False, cleanup=True, format=filename.suffix[1:] - ) # remove leading period from suffix + filename_without_suffix, + view=False, + cleanup=True, + format=filename.suffix[1:], + ) return graph diff --git a/test/test_model_rendering.py b/test/test_model_rendering.py index 62b7cf65f..a543add90 100644 --- a/test/test_model_rendering.py +++ b/test/test_model_rendering.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import os + import numpy as np import pytest @@ -8,7 +10,11 @@ import numpyro import numpyro.distributions as dist -from numpyro.infer.inspect import generate_graph_specification, get_model_relations +from numpyro.infer.inspect import ( + generate_graph_specification, + get_model_relations, + render_model, +) def simple(data): @@ -129,3 +135,12 @@ def test_model_transformation(test_model, model_kwargs, expected_graph_spec): graph_spec = generate_graph_specification(relations) assert graph_spec == expected_graph_spec + + +def test_render_model_filename(): + def model(): + numpyro.sample("x", dist.Normal(0, 1)) + + render_model(model, filename="graph.png") + assert os.path.exists("graph.png") + os.remove("graph.png")