From 17374dd8c8d5112adc134ce117ac05a17fddbf71 Mon Sep 17 00:00:00 2001 From: Martin Vonk Date: Thu, 14 Nov 2024 10:25:27 +0100 Subject: [PATCH 1/3] allow function to be parse to add_py_function --- pyemu/utils/pst_from.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/pyemu/utils/pst_from.py b/pyemu/utils/pst_from.py index d113cebd..dfcdd9c3 100644 --- a/pyemu/utils/pst_from.py +++ b/pyemu/utils/pst_from.py @@ -3,6 +3,10 @@ from pathlib import Path import warnings import platform +import string +import warnings +from pathlib import Path +from typing import Callable, Union import numpy as np import pandas as pd import pyemu @@ -1183,13 +1187,16 @@ def _next_count(self, prefix): return self._prefix_count[prefix] def add_py_function( - self, file_name, call_str=None, is_pre_cmd=True, - function_name=None + self, + file_name: Union[str, Callable], + call_str: Union[None, str] = None, + is_pre_cmd: Union[bool, None] = True, + function_name=None, ): """add a python function to the forward run script Args: - file_name (`str`): a python source file + file_name (`str` or `callable`): a python source file or function/callable call_str (`str`): the call string for python function in `file_name`. `call_str` will be added to the forward run script, as is. @@ -1209,7 +1216,7 @@ def add_py_function( `PstFrom.extra_py_imports` list. This function adds the `call_str` call to the forward - run script (either as a pre or post command or function not + run script (either as a pre or post command or function not directly called by main). It is up to users to make sure `call_str` is a valid python function call that includes the parentheses and requisite arguments @@ -1245,12 +1252,6 @@ def add_py_function( self.logger.lraise( "add_py_function(): No function call string passed in arg " "'call_str'" ) - if not os.path.exists(file_name): - self.logger.lraise( - "add_py_function(): couldnt find python source file '{0}'".format( - file_name - ) - ) if "(" not in call_str or ")" not in call_str: self.logger.lraise( "add_py_function(): call_str '{0}' missing paretheses".format(call_str) @@ -1266,7 +1267,16 @@ def add_py_function( f"original will be maintained", PyemuWarning, ) + if callable(file_name): + func_lines = getsource(file_name).splitlines(keepends=True) + self._function_lines_list.append(func_lines) else: + if not os.path.exists(file_name): + self.logger.lraise( + "add_py_function(): couldnt find python source file '{0}'".format( + file_name + ) + ) func_lines = [] search_str = "def " + function_name + "(" abet_set = set(string.ascii_uppercase) From 996520986cdd968f8026b4dc1a2d6494d460de27 Mon Sep 17 00:00:00 2001 From: Martin Vonk Date: Thu, 14 Nov 2024 10:27:13 +0100 Subject: [PATCH 2/3] update imports --- pyemu/utils/pst_from.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pyemu/utils/pst_from.py b/pyemu/utils/pst_from.py index dfcdd9c3..792c4f50 100644 --- a/pyemu/utils/pst_from.py +++ b/pyemu/utils/pst_from.py @@ -1,21 +1,22 @@ -from __future__ import print_function, division +from __future__ import division, print_function + +import copy import os -from pathlib import Path -import warnings import platform import string import warnings +from inspect import getsource from pathlib import Path from typing import Callable, Union + import numpy as np import pandas as pd -import pyemu -from ..pyemu_warnings import PyemuWarning -import copy -import string +import pyemu from pyemu.utils.helpers import _try_pdcol_numeric +from ..pyemu_warnings import PyemuWarning + # the tolerable percent difference (100 * (max - min)/mean) # used when checking that constant and zone type parameters are in fact constant (within # a given zone) From 8a522266af42947dcb068c50bd04264cf2fd8416 Mon Sep 17 00:00:00 2001 From: Martin Vonk Date: Fri, 15 Nov 2024 09:39:33 +0100 Subject: [PATCH 3/3] test add_py_function with callable --- autotest/pst_from_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autotest/pst_from_tests.py b/autotest/pst_from_tests.py index 93c7827b..6b2eba78 100644 --- a/autotest/pst_from_tests.py +++ b/autotest/pst_from_tests.py @@ -264,7 +264,7 @@ def freyberg_test(tmp_path): # (generated by pyemu.gw_utils.setup_hds_obs()) f, fdf = _gen_dummy_obs_file(pf.new_d) pf.add_observations(f, index_cols='idx', use_cols='yes') - pf.add_py_function(__file__, '_gen_dummy_obs_file()', + pf.add_py_function(_gen_dummy_obs_file, '_gen_dummy_obs_file()', is_pre_cmd=False) pf.add_observations('freyberg.hds.dat', insfile='freyberg.hds.dat.ins2', index_cols='obsnme', use_cols='obsval', prefix='hds')