Skip to content

Commit

Permalink
Merge branch 'main' into feature/pydantic-2-compatible-type-signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
jstvz authored Apr 8, 2024
2 parents 4b35008 + 9baecf5 commit 384cc2a
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 30 deletions.
2 changes: 1 addition & 1 deletion snowfakery/data_generator_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def __exit__(self, *args):
try:
plugin.close()
except Exception as e:
warn(f"Could not close {plugin} because {e}")
warn(f"Could not close {plugin} because {repr(e)}")
self.current_context = None
self.plugin_instances = None
self.plugin_function_libraries = None
Expand Down
10 changes: 5 additions & 5 deletions snowfakery/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __init__(self, interpreter):
def custom_functions(self, *args, **kwargs):
"""Instantiate, contextualize and return a function library
Default behaviour is to return self.Function()."""
functions = self.Functions()
Default behaviour is to return self.Functions()."""
functions = self.Functions() # type: ignore
functions.context = self.context
return functions

Expand Down Expand Up @@ -200,7 +200,7 @@ def resolve_plugins(
with plugin_path(search_paths):
plugins = []
for plugin_spec in plugin_specs:
plugins.extend(resolve_plugin(*plugin_spec))
plugins.extend(resolve_plugin(*plugin_spec)) # type: ignore
return plugins


Expand Down Expand Up @@ -312,7 +312,7 @@ def __init_subclass__(cls, **kwargs):


def _register_for_continuation(cls):
SnowfakeryDumper.add_representer(cls, Representer.represent_object)
SnowfakeryDumper.add_representer(cls, Representer.represent_object) # type: ignore
yaml.SafeLoader.add_constructor(
f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
lambda loader, node: cls._from_continuation(
Expand All @@ -327,7 +327,7 @@ class PluginResultIterator(PluginResult):
def __init__(self, repeat):
self.repeat = repeat

def __iter__(self):
def __iter__(self) -> T.Iterator:
return self

def __next__(self):
Expand Down
2 changes: 1 addition & 1 deletion snowfakery/salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def create_cci_record_type_tables(db_url: str):
_populate_rt_table(connection, table, record_type_column, rt_table)


def _create_record_type_table(tablename: str, metadata: MetaData):
def _create_record_type_table(tablename: str, metadata: MetaData) -> Table:
"""Create a table to store mapping between Record Type Ids and Developer Names."""
rt_map_fields = [
Column("record_type_id", Unicode(18), primary_key=True),
Expand Down
20 changes: 4 additions & 16 deletions snowfakery/standard_plugins/Salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,26 +330,14 @@ def _load_dataset(self, iteration_mode, rootpath, kwargs):
f"Unable to query records for {query}: {','.join(qs.job_result.job_errors)}"
)

self.tempdir, self.iterator = create_tempfile_sql_db_iterator(
tempdir, iterator = create_tempfile_sql_db_iterator(
iteration_mode, fieldnames, qs.get_results()
)
return self.iterator
iterator.cleanup.push(tempdir)
return iterator

def close(self):
if self.iterator:
self.iterator.close()
self.iterator = None

if self.tempdir:
self.tempdir.cleanup()
self.tempdir = None

def __del__(self):
# in case close was not called
# properly, try to do an orderly
# cleanup
self.close()

pass

def create_tempfile_sql_db_iterator(mode, fieldnames, results):
tempdir, db_url = _create_db(fieldnames, results)
Expand Down
4 changes: 2 additions & 2 deletions snowfakery/standard_plugins/Schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _process_special_cases(
add_date = self.ruleset.rdate
else: # pragma: no cover - Should be unreachable
assert action in ("include", "exclude"), "Bad action!"
raise NotImplementedError()
raise NotImplementedError("Bad action!")

if isinstance(case, (list, tuple)):
for case in case:
Expand Down Expand Up @@ -338,7 +338,7 @@ def next(self) -> None: # pragma: no cover
"""This method is never called.
It is replaced at runtime by _next_datetime or _next_date"""
raise NotImplementedError()
raise NotImplementedError("next is not implemented")

def _next_datetime(self) -> datetime:
return next(self.iterator)
Expand Down
22 changes: 17 additions & 5 deletions snowfakery/standard_plugins/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,17 @@ class DatasetIteratorBase(PluginResultIterator):
Subclasses should implement 'self.restart' which puts an iterator into 'self.results'
"""

def __init__(self, repeat):
# subclasses can register stuff to be cleaned up here.
self.cleanup = ExitStack()
super().__init__(repeat)

def next_result(self):
return next(self.results)

def close(self):
self.cleanup.close()


class SQLDatasetIterator(DatasetIteratorBase):
def __init__(self, engine, table, repeat):
Expand All @@ -86,6 +94,7 @@ def start(self):
def close(self):
self.results = None
self.connection.close()
super().close()

def query(self):
"Return a SQL Alchemy SELECT statement"
Expand All @@ -108,14 +117,13 @@ def query(self):

class CSVDatasetLinearIterator(DatasetIteratorBase):
def __init__(self, datasource: FileLike, repeat: bool):
self.cleanup = ExitStack()
super().__init__(repeat)
# utf-8-sig and newline="" are for Windows
self.path, self.file = self.cleanup.enter_context(
open_file_like(datasource, "r", newline="", encoding="utf-8-sig")
)

self.start()
super().__init__(repeat)

def start(self):
assert self.file
Expand All @@ -127,7 +135,7 @@ def start(self):

def close(self):
self.results = None
self.cleanup.close()
super().close()

def plugin_result(self, row):
if None in row:
Expand Down Expand Up @@ -190,13 +198,17 @@ def _get_dataset_instance(self, plugin_context, iteration_mode, kwargs):
return dataset_instance

def _load_dataset(self, iteration_mode, rootpath, kwargs):
raise NotImplementedError()
raise NotImplementedError("_load_dataset not implemented")

def close(self):
raise NotImplementedError()
raise NotImplementedError("close not implemented: " + repr(self))


class FileDataset(DatasetBase):

def close(self):
pass

def _load_dataset(self, iteration_mode, rootpath, kwargs):
dataset = kwargs.get("dataset")
tablename = kwargs.get("table")
Expand Down
17 changes: 17 additions & 0 deletions tests/multiple-datasets.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
- plugin: snowfakery.standard_plugins.Salesforce.SOQLDataset
- object: Contact
count: 10
fields:
__users_from_salesforce:
SOQLDataset.shuffle:
fields: Id, FirstName, LastName
from: User
__Account_from_Salesforce:
SOQLDataset.shuffle:
fields: Id
from: Account
# The next line depends on the users having particular
# permissions.
FirstName: ${{__users_from_salesforce.FirstName}}
LastName: ${{__users_from_salesforce.LastName}}
AccountId: ${{__Account_from_Salesforce.Id}}

0 comments on commit 384cc2a

Please sign in to comment.