Commit 0aac300: fix stream logic
Mistobaan committed May 22, 2020 (1 parent: 300d101)
Showing 3 changed files with 12 additions and 16 deletions.

4 changes: 2 additions & 2 deletions deepmatcher/data/__init__.py
@@ -1,10 +1,10 @@
 from .field import MatchingField, reset_vector_cache
 from .dataset import MatchingDataset
 from .iterator import MatchingIterator
-from .process import process, process_unlabeled
+from .process import process, process_unlabeled, process_unlabeled_stream
 from .dataset import split
 
 __all__ = [
-    'MatchingField', 'MatchingDataset', 'MatchingIterator', 'process', 'process_unlabeled', 'split',
+    'MatchingField', 'MatchingDataset', 'MatchingIterator', 'process', 'process_unlabeled', 'process_unlabeled_stream', 'split',
     'reset_vector_cache'
 ]
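
With the re-export above, streaming prediction becomes importable from deepmatcher.data. A minimal sketch, assuming an environment with this commit installed; whether the symbol is also re-exported at the deepmatcher top level depends on the package's root __init__.py, which this diff does not touch:

    # New in this commit: process_unlabeled_stream sits alongside the
    # existing path-based process_unlabeled entry point.
    from deepmatcher.data import process_unlabeled, process_unlabeled_stream
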
21 changes: 8 additions & 13 deletions deepmatcher/data/dataset.py
@@ -4,7 +4,7 @@
 import logging
 import os
 import pdb
-from collections import Counter, defaultdict
+from collections import Counter, defaultdict, Iterator
 from timeit import default_timer as timer
 
 import pandas as pd
@@ -69,14 +69,14 @@ def split(table,
         tables[i].to_csv(os.path.join(path, prefixes[i]), index=False)
 
 
-class CountingWrapper(object):
+class CountingWrapper(Iterator):
     def __init__(self, stream):
         self.line_count = 0
         self.f = stream
 
-    def read(self, *args, **kwargs):
+    def __next__(self):
         self.line_count += 1
-        return self.f.read(*args, **kwargs)
+        return self.f.readline()
 
 
 class MatchingDataset(data.Dataset):
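
The wrapper now behaves as a line iterator rather than a read() proxy, which is what lets MatchingDataset below consume a raw stream while still tracking how many lines went by. A minimal sketch of the same idea, with two flagged deviations from the diff: the Iterator ABC is imported from collections.abc (importing it from collections was deprecated in Python 3.3 and removed in 3.10), and EOF raises StopIteration, since readline() returns '' forever once the stream is exhausted and the diff's version would otherwise never terminate in a for loop:

    import io
    from collections.abc import Iterator

    class CountingWrapper(Iterator):
        def __init__(self, stream):
            self.line_count = 0
            self.f = stream

        def __next__(self):
            line = self.f.readline()
            if not line:  # readline() returns '' at end of stream
                raise StopIteration
            self.line_count += 1
            return line

    wrapped = CountingWrapper(io.StringIO('id,a,b\n0,x,y\n'))
    for _ in wrapped:
        pass
    print(wrapped.line_count)  # 2
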
@@ -167,13 +167,7 @@ def __init__(self,
                     reader = f
                     lines = f.line_count
 
-                next(reader)
-                examples = [
-                    make_example(line, fields) for line in pyprind.prog_bar(
-                        reader,
-                        iterations=lines,
-                        title='\nReading and processing data from "' + path + '"')
-                ]
+                examples = [ make_example(line, fields) for line in reader ]
 
             super(MatchingDataset, self).__init__(examples, fields, **kwargs)
         else:
@@ -216,6 +210,7 @@ def _set_attributes(self):
         self.label_field = self.column_naming['label']
         self.id_field = self.column_naming['id']
 
+
     def compute_metadata(self, pca=False):
         r"""Computes metadata about the dataset.
@@ -634,7 +629,7 @@ def splits(cls,
                                            cachefile, column_naming,
                                            state_args)
                 after_cache = timer()
-                logger.info('Cache save time: {time}s', time=after_cache -
+                logger.info('Cache save time: {time}s', time=(after_cache -
                             after_vocab))
 
         if train:
3 changes: 2 additions & 1 deletion deepmatcher/data/process.py
@@ -248,6 +248,7 @@ def process_unlabeled(path, trained_model, ignore_columns=None):
     with io.open(path, encoding="utf8") as f:
         return process_unlabeled_stream(f, trained_model, ignore_columns)
 
+
 def process_unlabeled_stream(stream, trained_model, ignore_columns=None):
     """Creates a dataset object for an unlabeled dataset.
@@ -274,7 +275,7 @@ def process_unlabeled_stream(stream, trained_model, ignore_columns=None):
                           train_info.tokenize, train_info.include_lengths)
 
     begin = timer()
-    dataset_args = {'fields': fields, 'column_naming': column_naming}
+    dataset_args = {'fields': fields, 'column_naming': column_naming, 'format':'csv'}
    dataset = MatchingDataset(stream=stream, **dataset_args)
 
     # Make sure we have the same attributes.
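
End to end, the three files allow prediction over an in-memory CSV stream instead of a file on disk. A hedged usage sketch; the saved-model path, column names, and CSV payload below are hypothetical stand-ins for illustration, and the model is assumed to have been trained on a matching schema with these same columns:

    import io
    import deepmatcher as dm
    from deepmatcher.data import process_unlabeled_stream

    model = dm.MatchingModel()
    model.load_state('hybrid_model.pth')  # hypothetical saved-model path

    payload = io.StringIO('id,left_name,right_name\n'
                          '0,apple iphone 7,iphone 7 32gb\n')
    candidates = process_unlabeled_stream(payload, model)
    predictions = model.run_prediction(candidates)
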
