Skip to content

Commit

Permalink
add support for stream input
Browse files Browse the repository at this point in the history
  • Loading branch information
Mistobaan committed May 22, 2020
1 parent 78c90f9 commit 300d101
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions deepmatcher/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def split(table,
tables[i].to_csv(os.path.join(path, prefixes[i]), index=False)


class CountingWrapper(object):
    """Wrap a readable stream and keep a running count of lines consumed.

    The wrapper can be used in two ways:

    * as a line iterator (``for line in wrapper`` / ``next(wrapper)``), which
      is what ``csv.reader`` requires of its input; every line yielded by the
      underlying stream bumps ``line_count`` by one;
    * through ``read()``, in which case ``line_count`` grows by the number of
      newline separators found in the returned data (a trailing line without
      a terminating newline is therefore not counted by ``read()``).

    Attributes:
        line_count: Number of lines consumed through this wrapper so far.
        f: The wrapped stream object.
    """

    def __init__(self, stream):
        self.line_count = 0
        self.f = stream

    def __iter__(self):
        # The wrapper is its own iterator so the count stays attached to it.
        return self

    def __next__(self):
        # Let StopIteration from the underlying stream propagate untouched;
        # the counter is only bumped when a line was actually produced.
        line = next(self.f)
        self.line_count += 1
        return line

    # Python 2 spells the iterator protocol method ``next``.
    next = __next__

    def read(self, *args, **kwargs):
        """Read from the wrapped stream, forwarding all arguments.

        Returns whatever the underlying ``read`` returns. ``line_count`` is
        increased by the number of newline characters in that data rather
        than once per call, so the counter reflects lines, not read calls.
        """
        data = self.f.read(*args, **kwargs)
        if data:
            # Support both text and binary streams.
            newline = b"\n" if isinstance(data, bytes) else "\n"
            self.line_count += data.count(newline)
        return data


class MatchingDataset(data.Dataset):
r"""Represents dataset with associated metadata.
Expand Down Expand Up @@ -101,7 +111,7 @@ class CacheStaleException(Exception):
def __init__(self,
fields,
column_naming,
path=None,
stream=None,
format='csv',
examples=None,
metadata=None,
Expand Down Expand Up @@ -141,26 +151,29 @@ def __init__(self,
"""
if examples is None:
make_example = {
'json': Example.fromJSON, 'dict': Example.fromdict,
'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format.lower()]
'json': Example.fromJSON,
'dict': Example.fromdict,
'tsv': Example.fromCSV,
'csv': Example.fromCSV
}[format.lower()]

f = CountingWrapper(stream)
lines = 0
with open(os.path.expanduser(path), encoding="utf8") as f:
for line in f:
lines += 1

with open(os.path.expanduser(path), encoding="utf8") as f:
if format == 'csv':
reader = unicode_csv_reader(f)
elif format == 'tsv':
reader = unicode_csv_reader(f, delimiter='\t')
else:
reader = f

next(reader)
examples = [make_example(line, fields) for line in
pyprind.prog_bar(reader, iterations=lines,
title='\nReading and processing data from "' + path + '"')]
if format == 'csv':
reader = unicode_csv_reader(f)
elif format == 'tsv':
reader = unicode_csv_reader(f, delimiter='\t')
else:
reader = f
lines = f.line_count

next(reader)
examples = [
make_example(line, fields) for line in pyprind.prog_bar(
reader,
iterations=lines,
title='\nReading and processing data from "' + path + '"')
]

super(MatchingDataset, self).__init__(examples, fields, **kwargs)
else:
Expand Down

0 comments on commit 300d101

Please sign in to comment.