diff --git a/deepmatcher/data/dataset.py b/deepmatcher/data/dataset.py index a589508..8609575 100644 --- a/deepmatcher/data/dataset.py +++ b/deepmatcher/data/dataset.py @@ -69,6 +69,16 @@ def split(table, tables[i].to_csv(os.path.join(path, prefixes[i]), index=False) +class CountingWrapper(object): + def __init__(self, stream): + self.line_count = 0 + self.f = stream + + def read(self, *args, **kwargs): + self.line_count += 1 + return self.f.read(*args, **kwargs) + + class MatchingDataset(data.Dataset): r"""Represents dataset with associated metadata. @@ -101,7 +111,7 @@ class CacheStaleException(Exception): def __init__(self, fields, column_naming, - path=None, + stream=None, format='csv', examples=None, metadata=None, @@ -141,26 +151,29 @@ def __init__(self, """ if examples is None: make_example = { - 'json': Example.fromJSON, 'dict': Example.fromdict, - 'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format.lower()] + 'json': Example.fromJSON, + 'dict': Example.fromdict, + 'tsv': Example.fromCSV, + 'csv': Example.fromCSV + }[format.lower()] + f = CountingWrapper(stream) lines = 0 - with open(os.path.expanduser(path), encoding="utf8") as f: - for line in f: - lines += 1 - - with open(os.path.expanduser(path), encoding="utf8") as f: - if format == 'csv': - reader = unicode_csv_reader(f) - elif format == 'tsv': - reader = unicode_csv_reader(f, delimiter='\t') - else: - reader = f - - next(reader) - examples = [make_example(line, fields) for line in - pyprind.prog_bar(reader, iterations=lines, - title='\nReading and processing data from "' + path + '"')] + if format == 'csv': + reader = unicode_csv_reader(f) + elif format == 'tsv': + reader = unicode_csv_reader(f, delimiter='\t') + else: + reader = f + lines = f.line_count + + next(reader) + examples = [ + make_example(line, fields) for line in pyprind.prog_bar( + reader, + iterations=lines, + title='\nReading and processing data from "' + path + '"') + ] super(MatchingDataset, self).__init__(examples, fields, **kwargs) else: