pep8 format
Mistobaan committed May 27, 2020
1 parent 100fa3a commit 50192cd
Showing 15 changed files with 653 additions and 511 deletions.
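The hunks shown below are formatting-only changes: the same statements are re-wrapped with continuation arguments aligned under the opening parenthesis, a layout typical of an autoformatter such as yapf. The commit does not record which tool or style configuration was used, so the following is only a hedged sketch of how such a reformat could be driven from Python, assuming yapf is installed; with default pep8 settings the output will not necessarily match this commit byte for byte.

# Illustrative sketch only: reformat a snippet through yapf's Python API.
# The choice of tool and style is an assumption, not something this commit states.
from yapf.yapflib.yapf_api import FormatCode

old_snippet = (
    'warnings.warn("deprecated, use deepmatcher.data.process instead",\n'
    '              DeprecationWarning)\n'
)
new_snippet, changed = FormatCode(old_snippet, style_config='pep8')
print(new_snippet)  # the statement re-wrapped to the configured column limit
print(changed)      # True if the formatter modified the input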
14 changes: 8 additions & 6 deletions deepmatcher/__init__.py
@@ -29,19 +29,21 @@


def process(*args, **kwargs):
warnings.warn('"deepmatcher.process" is deprecated and will be removed in a later '
'release, please use "deepmatcher.data.process" instead',
DeprecationWarning)
warnings.warn(
'"deepmatcher.process" is deprecated and will be removed in a later '
'release, please use "deepmatcher.data.process" instead',
DeprecationWarning)
return data_process(*args, **kwargs)


__version__ = '0.1.1'
__author__ = 'Sidharth Mudgal, Han Li'

__all__ = [
'attr_summarizers', 'word_aggregators', 'word_comparators', 'word_contextualizers',
'process', 'MatchingModel', 'AttrSummarizer', 'WordContextualizer', 'WordComparator',
'WordAggregator', 'Classifier', 'modules'
'attr_summarizers', 'word_aggregators', 'word_comparators',
'word_contextualizers', 'process', 'MatchingModel', 'AttrSummarizer',
'WordContextualizer', 'WordComparator', 'WordAggregator', 'Classifier',
'modules'
]

_check_nan = True
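The hunk above only re-indents the deprecation shim: deepmatcher.process still emits a DeprecationWarning and then delegates to deepmatcher.data.process. A minimal usage sketch follows; the path and CSV file names are hypothetical, and the keyword names follow the library's documented data-processing API.

# Hypothetical usage of the deprecated alias; directory and file names are made up.
# Calling dm.process warns and forwards to dm.data.process.
import warnings
import deepmatcher as dm

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    datasets = dm.process(path='sample_data',
                          train='train.csv',
                          validation='validation.csv',
                          test='test.csv')
assert any(issubclass(w.category, DeprecationWarning) for w in caught)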
15 changes: 7 additions & 8 deletions deepmatcher/batch.py
@@ -29,7 +29,6 @@ class AttrTensor(AttrTensor_):
assert(name_attr.word_probs == word_probs)
assert(name_attr.pc == pc)
"""

@staticmethod
def __new__(cls, *args, **kwargs):
if len(kwargs) == 0:
@@ -47,8 +46,8 @@ def __new__(cls, *args, **kwargs):
word_probs = None
if 'word_probs' in train_info.metadata:
raw_word_probs = train_info.metadata['word_probs'][name]
word_probs = torch.Tensor(
[[raw_word_probs[int(w)] for w in b] for b in data.data])
word_probs = torch.Tensor([[raw_word_probs[int(w)] for w in b]
for b in data.data])
if data.is_cuda:
word_probs = word_probs.cuda()
pc = None
@@ -86,14 +85,14 @@ class MatchingBatch(object):
name_attr = mbatch.name
category_attr = mbatch.category
"""

def __init__(self, input, train_info):
copy_fields = train_info.all_text_fields
for name in copy_fields:
setattr(self, name,
AttrTensor(
name=name, attr=getattr(input, name),
train_info=train_info))
setattr(
self, name,
AttrTensor(name=name,
attr=getattr(input, name),
train_info=train_info))
for name in [train_info.label_field, train_info.id_field]:
if name is not None and hasattr(input, name):
setattr(self, name, getattr(input, name))
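For context on the code being re-wrapped here: AttrTensor is a namedtuple subclass whose __new__ either forwards plain positional fields or derives them from name/attr/train_info keyword arguments, and MatchingBatch attaches one AttrTensor per text column of the input batch. A generic, self-contained sketch of that namedtuple pattern (not the library's actual implementation; the keyword branch is deliberately simplified):

# Generic sketch of a namedtuple with a custom __new__, mirroring the pattern
# visible in AttrTensor above. The field derivation here is illustrative only.
from collections import namedtuple

AttrTensorBase = namedtuple('AttrTensorBase',
                            ['data', 'lengths', 'word_probs', 'pc'])

class AttrTensorSketch(AttrTensorBase):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        if len(kwargs) == 0:
            # Positional construction behaves like the plain namedtuple.
            return super(AttrTensorSketch, cls).__new__(cls, *args)
        # Keyword construction fills in whatever fields were not provided.
        data = kwargs['data']
        lengths = kwargs.get('lengths')
        word_probs = kwargs.get('word_probs')
        pc = kwargs.get('pc')
        return super(AttrTensorSketch, cls).__new__(cls, data, lengths,
                                                    word_probs, pc)

attr = AttrTensorSketch(data=[3, 1, 4])
assert attr.lengths is None and attr.pc is None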
127 changes: 84 additions & 43 deletions deepmatcher/data/dataset.py
@@ -94,11 +94,9 @@ class MatchingDataset(data.Dataset):
label_field (str): Name of the column containing labels.
id_field (str): Name of the column containing tuple pair ids.
"""

class CacheStaleException(Exception):
r"""Raised when the dataset cache is stale and no fallback behavior is specified.
"""
pass

def __init__(self,
fields,
@@ -143,8 +141,11 @@ def __init__(self,
"""
if examples is None:
make_example = {
'json': Example.fromJSON, 'dict': Example.fromdict,
'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format.lower()]
'json': Example.fromJSON,
'dict': Example.fromdict,
'tsv': Example.fromCSV,
'csv': Example.fromCSV
}[format.lower()]

lines = 0
with open(os.path.expanduser(path), encoding="utf8") as f:
@@ -160,9 +161,13 @@ def __init__(self,
reader = f

next(reader)
examples = [make_example(line, fields) for line in
pyprind.prog_bar(reader, iterations=lines,
title='\nReading and processing data from "' + path + '"')]
examples = [
make_example(line, fields) for line in pyprind.prog_bar(
reader,
iterations=lines,
title='\nReading and processing data from "' + path +
'"')
]

super(MatchingDataset, self).__init__(examples, fields, **kwargs)
else:
@@ -182,12 +187,14 @@ def _set_attributes(self):

self.all_left_fields = []
for name, field in six.iteritems(self.fields):
if name.startswith(self.column_naming['left']) and field is not None:
if name.startswith(
self.column_naming['left']) and field is not None:
self.all_left_fields.append(name)

self.all_right_fields = []
for name, field in six.iteritems(self.fields):
if name.startswith(self.column_naming['right']) and field is not None:
if name.startswith(
self.column_naming['right']) and field is not None:
self.all_right_fields.append(name)

self.canonical_text_fields = []
@@ -224,13 +231,18 @@ def compute_metadata(self, pca=False):
self.metadata = {}

# Create an iterator over the entire dataset.
train_iter = MatchingIterator(
self, self, train=False, batch_size=1024, device='cpu', sort_in_buckets=False)
train_iter = MatchingIterator(self,
self,
train=False,
batch_size=1024,
device='cpu',
sort_in_buckets=False)
counter = defaultdict(Counter)

# For each attribute, find the number of times each word id occurs in the dataset.
# Note that word ids here also include ``UNK`` tokens, padding tokens, etc.
for batch in pyprind.prog_bar(train_iter, title='\nBuilding vocabulary'):
for batch in pyprind.prog_bar(train_iter,
title='\nBuilding vocabulary'):
for name in self.all_text_fields:
attr_input = getattr(batch, name)
counter[name].update(attr_input.data.data.view(-1).tolist())
@@ -269,14 +281,18 @@ def compute_metadata(self, pca=False):
embed[name] = field_embed[field]

# Create an iterator over the entire dataset.
train_iter = MatchingIterator(
self, self, train=False, batch_size=1024, device='cpu', sort_in_buckets=False)
train_iter = MatchingIterator(self,
self,
train=False,
batch_size=1024,
device='cpu',
sort_in_buckets=False)
attr_embeddings = defaultdict(list)

# Run the constructed neural network to compute weighted sequence embeddings
# for each attribute of each example in the dataset.
for batch in pyprind.prog_bar(train_iter,
title='\nComputing principal components'):
for batch in pyprind.prog_bar(
train_iter, title='\nComputing principal components'):
for name in self.all_text_fields:
attr_input = getattr(batch, name)
embeddings = inv_freq_pool(embed[name](attr_input))
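The two compute_metadata hunks above first count how often each word id occurs per attribute and then pool the word embeddings with an inverse-frequency weighting (inv_freq_pool) before extracting principal components. The exact weighting is not part of this diff, so the following is only a rough, self-contained sketch of the general idea; the weighting constant and the normalization are illustrative, not deepmatcher's formula.

# Rough sketch of inverse-frequency weighted pooling over a batch of word
# embeddings; constants and normalization are illustrative only.
import torch

def inv_freq_pool_sketch(embeddings, word_probs, a=1e-3):
    # embeddings: (batch, seq_len, dim); word_probs: (batch, seq_len)
    weights = a / (a + word_probs)               # rarer words get larger weights
    weighted = embeddings * weights.unsqueeze(-1)
    return weighted.sum(dim=1) / weights.sum(dim=1, keepdim=True)

emb = torch.randn(2, 5, 8)                       # two sequences, five words, dim 8
probs = torch.rand(2, 5).clamp(min=1e-6)         # fake per-word occurrence probabilities
pooled = inv_freq_pool_sketch(emb, probs)        # -> (2, 8) sequence embeddings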
@@ -312,7 +328,8 @@ def get_raw_table(self):
using the whitespace delimiter.
"""
rows = []
columns = list(name for name, field in six.iteritems(self.fields) if field)
columns = list(name for name, field in six.iteritems(self.fields)
if field)
for ex in self.examples:
row = []
for attr in columns:
@@ -331,10 +348,12 @@ def sort_key(self, ex):
A key to use for sorting dataset examples for batching together examples with
similar lengths to minimize padding."""

return interleave_keys([len(getattr(ex, attr)) for attr in self.all_text_fields])
return interleave_keys(
[len(getattr(ex, attr)) for attr in self.all_text_fields])

@staticmethod
def save_cache(datasets, fields, datafiles, cachefile, column_naming, state_args):
def save_cache(datasets, fields, datafiles, cachefile, column_naming,
state_args):
r"""Save datasets and corresponding metadata to cache.
This method also saves as many data loading arguments as possible to help ensure
@@ -355,7 +374,9 @@ def save_cache(datasets, fields, datafiles, cachefile, column_naming, state_args
"""
examples = [dataset.examples for dataset in datasets]
train_metadata = datasets[0].metadata
datafiles_modified = [os.path.getmtime(datafile) for datafile in datafiles]
datafiles_modified = [
os.path.getmtime(datafile) for datafile in datafiles
]
vocabs = {}
field_args = {}
reverse_fields = {}
@@ -423,18 +444,23 @@ def load_cache(fields, datafiles, cachefile, column_naming, state_args):
if datafiles != cached_data['datafiles']:
cache_stale_cause.add('Data file list has changed.')

datafiles_modified = [os.path.getmtime(datafile) for datafile in datafiles]
datafiles_modified = [
os.path.getmtime(datafile) for datafile in datafiles
]
if datafiles_modified != cached_data['datafiles_modified']:
cache_stale_cause.add('One or more data files have been modified.')

if set(fields.keys()) != set(cached_data['field_args'].keys()):
cache_stale_cause.add('Fields have changed.')

for name, field in six.iteritems(fields):
none_mismatch = (field is None) != (cached_data['field_args'][name] is None)
none_mismatch = (field is None) != (cached_data['field_args'][name]
is None)
args_mismatch = False
if field is not None and cached_data['field_args'][name] is not None:
args_mismatch = field.preprocess_args() != cached_data['field_args'][name]
if field is not None and cached_data['field_args'][
name] is not None:
args_mismatch = field.preprocess_args(
) != cached_data['field_args'][name]
if none_mismatch or args_mismatch:
cache_stale_cause.add('Field arguments have changed.')
if field is not None and not isinstance(field, MatchingField):
@@ -444,8 +470,8 @@ def load_cache(fields, datafiles, cachefile, column_naming, state_args):
cache_stale_cause.add('Other arguments have changed.')

cache_stale_cause.update(
MatchingDataset.state_args_compatibility(state_args,
cached_data['state_args']))
MatchingDataset.state_args_compatibility(
state_args, cached_data['state_args']))

return cached_data, cache_stale_cause
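load_cache treats the cache as stale when the list of data files, their modification times, or the field arguments no longer match what was pickled, and splits() below decides whether to raise or to rebuild based on auto_rebuild_cache. A minimal standalone sketch of that decision (names are illustrative; cached stands in for the unpickled cache dict):

# Minimal sketch of the staleness check and the raise-or-rebuild decision;
# this mirrors the logic visible in load_cache and splits, not the real API.
import os

def check_cache(datafiles, cached, auto_rebuild_cache=True):
    causes = set()
    if datafiles != cached['datafiles']:
        causes.add('Data file list has changed.')
    mtimes = [os.path.getmtime(f) for f in datafiles]
    if mtimes != cached['datafiles_modified']:
        causes.add('One or more data files have been modified.')
    if causes and not auto_rebuild_cache:
        raise RuntimeError('Stale cache: %s' % sorted(causes))
    return causes  # a non-empty set means the cache should be rebuilt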

@@ -535,59 +561,75 @@ def splits(cls,

datasets = None
if cache:
datafiles = list(f for f in (train, validation, test) if f is not None)
datafiles = [os.path.expanduser(os.path.join(path, d)) for d in datafiles]
datafiles = list(f for f in (train, validation, test)
if f is not None)
datafiles = [
os.path.expanduser(os.path.join(path, d)) for d in datafiles
]
cachefile = os.path.expanduser(os.path.join(path, cache))
try:
cached_data, cache_stale_cause = MatchingDataset.load_cache(
fields_dict, datafiles, cachefile, column_naming, state_args)
fields_dict, datafiles, cachefile, column_naming,
state_args)

if check_cached_data and cache_stale_cause:
if not auto_rebuild_cache:
raise MatchingDataset.CacheStaleException(cache_stale_cause)
raise MatchingDataset.CacheStaleException(
cache_stale_cause)
else:
logger.warning('Rebuilding data cache because: %s', list(cache_stale_cause))
logger.warning('Rebuilding data cache because: %s',
list(cache_stale_cause))

if not check_cached_data or not cache_stale_cause:
datasets = MatchingDataset.restore_data(fields, cached_data)
datasets = MatchingDataset.restore_data(
fields, cached_data)

except IOError:
pass

if not datasets:
begin = timer()
dataset_args = {'fields': fields, 'column_naming': column_naming, **kwargs}
dataset_args = {
'fields': fields,
'column_naming': column_naming,
**kwargs
}
train_data = None if train is None else cls(
path=os.path.join(path, train), **dataset_args)
val_data = None if validation is None else cls(
path=os.path.join(path, validation), **dataset_args)
test_data = None if test is None else cls(
path=os.path.join(path, test), **dataset_args)
datasets = tuple(
d for d in (train_data, val_data, test_data) if d is not None)
datasets = tuple(d for d in (train_data, val_data, test_data)
if d is not None)

after_load = timer()
logger.info('Data load took: {}s'.format(after_load - begin))

fields_set = set(fields_dict.values())
for field in fields_set:
if field is not None and field.use_vocab:
field.build_vocab(
*datasets, vectors=embeddings, cache=embeddings_cache)
field.build_vocab(*datasets,
vectors=embeddings,
cache=embeddings_cache)
after_vocab = timer()
logger.info('Vocab construction time: {}s'.format(after_vocab - after_load))
logger.info('Vocab construction time: {}s'.format(after_vocab -
after_load))

if train:
datasets[0].compute_metadata(train_pca)
after_metadata = timer()
logger.info(
'Metadata computation time: {}s'.format(after_metadata - after_vocab))
'Metadata computation time: {}s'.format(after_metadata -
after_vocab))

if cache:
MatchingDataset.save_cache(datasets, fields_dict, datafiles, cachefile,
column_naming, state_args)
MatchingDataset.save_cache(datasets, fields_dict, datafiles,
cachefile, column_naming,
state_args)
after_cache = timer()
logger.info('Cache save time: {}s'.format(after_cache - after_vocab))
logger.info('Cache save time: {}s'.format(after_cache -
after_vocab))

if train:
datasets[0].finalize_metadata()
@@ -616,7 +658,6 @@ def interleave_keys(keys):
values for the key defined by this function. Useful for tasks with two
text fields like machine translation or natural language inference.
"""

def interleave(args):
return ''.join([x for t in zip(*args) for x in t])
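The interleave helper above alternates the characters of its equal-length key strings, so a single sort on the interleaved key approximately sorts on all per-attribute lengths at once. A quick worked example of just that helper:

# Worked example of the interleave helper shown above: characters of the
# input keys alternate in the result.
def interleave(args):
    return ''.join([x for t in zip(*args) for x in t])

assert interleave(['07', '12']) == '0172'
assert interleave(['ab', 'cd', 'ef']) == 'acebdf'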
