Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow parsing function to add_py_function #549

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autotest/pst_from_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def freyberg_test(tmp_path):
# (generated by pyemu.gw_utils.setup_hds_obs())
f, fdf = _gen_dummy_obs_file(pf.new_d)
pf.add_observations(f, index_cols='idx', use_cols='yes')
pf.add_py_function(__file__, '_gen_dummy_obs_file()',
pf.add_py_function(_gen_dummy_obs_file, '_gen_dummy_obs_file()',
is_pre_cmd=False)
pf.add_observations('freyberg.hds.dat', insfile='freyberg.hds.dat.ins2',
index_cols='obsnme', use_cols='obsval', prefix='hds')
Expand Down
45 changes: 28 additions & 17 deletions pyemu/utils/pst_from.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from __future__ import print_function, division
from __future__ import division, print_function

import copy
import os
from pathlib import Path
import warnings
import platform
import string
import warnings
from inspect import getsource
from pathlib import Path
from typing import Callable, Union

import numpy as np
import pandas as pd
import pyemu
from ..pyemu_warnings import PyemuWarning
import copy
import string

import pyemu
from pyemu.utils.helpers import _try_pdcol_numeric

from ..pyemu_warnings import PyemuWarning

# the tolerable percent difference (100 * (max - min)/mean)
# used when checking that constant and zone type parameters are in fact constant (within
# a given zone)
Expand Down Expand Up @@ -1183,13 +1188,16 @@ def _next_count(self, prefix):
return self._prefix_count[prefix]

def add_py_function(
self, file_name, call_str=None, is_pre_cmd=True,
function_name=None
self,
file_name: Union[str, Callable],
call_str: Union[None, str] = None,
is_pre_cmd: Union[bool, None] = True,
function_name=None,
):
"""add a python function to the forward run script

Args:
file_name (`str`): a python source file
file_name (`str` or `callable`): a python source file or function/callable
call_str (`str`): the call string for python function in
`file_name`.
`call_str` will be added to the forward run script, as is.
Expand All @@ -1209,7 +1217,7 @@ def add_py_function(
`PstFrom.extra_py_imports` list.

This function adds the `call_str` call to the forward
run script (either as a pre or post command or function not
run script (either as a pre or post command or function not
directly called by main). It is up to users
to make sure `call_str` is a valid python function call
that includes the parentheses and requisite arguments
Expand Down Expand Up @@ -1245,12 +1253,6 @@ def add_py_function(
self.logger.lraise(
"add_py_function(): No function call string passed in arg " "'call_str'"
)
if not os.path.exists(file_name):
self.logger.lraise(
"add_py_function(): couldnt find python source file '{0}'".format(
file_name
)
)
if "(" not in call_str or ")" not in call_str:
self.logger.lraise(
"add_py_function(): call_str '{0}' missing paretheses".format(call_str)
Expand All @@ -1266,7 +1268,16 @@ def add_py_function(
f"original will be maintained",
PyemuWarning,
)
if callable(file_name):
func_lines = getsource(file_name).splitlines(keepends=True)
self._function_lines_list.append(func_lines)
else:
if not os.path.exists(file_name):
self.logger.lraise(
"add_py_function(): couldnt find python source file '{0}'".format(
file_name
)
)
func_lines = []
search_str = "def " + function_name + "("
abet_set = set(string.ascii_uppercase)
Expand Down