diff --git a/deepmatcher/__init__.py b/deepmatcher/__init__.py index 4f1f7be..6ca8779 100644 --- a/deepmatcher/__init__.py +++ b/deepmatcher/__init__.py @@ -29,9 +29,10 @@ def process(*args, **kwargs): - warnings.warn('"deepmatcher.process" is deprecated and will be removed in a later ' - 'release, please use "deepmatcher.data.process" instead', - DeprecationWarning) + warnings.warn( + '"deepmatcher.process" is deprecated and will be removed in a later ' + 'release, please use "deepmatcher.data.process" instead', + DeprecationWarning) return data_process(*args, **kwargs) @@ -39,9 +40,10 @@ def process(*args, **kwargs): __author__ = 'Sidharth Mudgal, Han Li' __all__ = [ - 'attr_summarizers', 'word_aggregators', 'word_comparators', 'word_contextualizers', - 'process', 'MatchingModel', 'AttrSummarizer', 'WordContextualizer', 'WordComparator', - 'WordAggregator', 'Classifier', 'modules' + 'attr_summarizers', 'word_aggregators', 'word_comparators', + 'word_contextualizers', 'process', 'MatchingModel', 'AttrSummarizer', + 'WordContextualizer', 'WordComparator', 'WordAggregator', 'Classifier', + 'modules' ] _check_nan = True diff --git a/deepmatcher/batch.py b/deepmatcher/batch.py index 4b05f94..5903284 100644 --- a/deepmatcher/batch.py +++ b/deepmatcher/batch.py @@ -29,7 +29,6 @@ class AttrTensor(AttrTensor_): assert(name_attr.word_probs == word_probs) assert(name_attr.pc == pc) """ - @staticmethod def __new__(cls, *args, **kwargs): if len(kwargs) == 0: @@ -47,8 +46,8 @@ def __new__(cls, *args, **kwargs): word_probs = None if 'word_probs' in train_info.metadata: raw_word_probs = train_info.metadata['word_probs'][name] - word_probs = torch.Tensor( - [[raw_word_probs[int(w)] for w in b] for b in data.data]) + word_probs = torch.Tensor([[raw_word_probs[int(w)] for w in b] + for b in data.data]) if data.is_cuda: word_probs = word_probs.cuda() pc = None @@ -86,14 +85,14 @@ class MatchingBatch(object): name_attr = mbatch.name category_attr = mbatch.category """ - def __init__(self, input, train_info): copy_fields = train_info.all_text_fields for name in copy_fields: - setattr(self, name, - AttrTensor( - name=name, attr=getattr(input, name), - train_info=train_info)) + setattr( + self, name, + AttrTensor(name=name, + attr=getattr(input, name), + train_info=train_info)) for name in [train_info.label_field, train_info.id_field]: if name is not None and hasattr(input, name): setattr(self, name, getattr(input, name)) diff --git a/deepmatcher/data/dataset.py b/deepmatcher/data/dataset.py index 36a657f..23fd094 100644 --- a/deepmatcher/data/dataset.py +++ b/deepmatcher/data/dataset.py @@ -94,11 +94,9 @@ class MatchingDataset(data.Dataset): label_field (str): Name of the column containing labels. id_field (str): Name of the column containing tuple pair ids. """ - class CacheStaleException(Exception): r"""Raised when the dataset cache is stale and no fallback behavior is specified. """ - pass def __init__(self, fields, @@ -143,8 +141,11 @@ def __init__(self, """ if examples is None: make_example = { - 'json': Example.fromJSON, 'dict': Example.fromdict, - 'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format.lower()] + 'json': Example.fromJSON, + 'dict': Example.fromdict, + 'tsv': Example.fromCSV, + 'csv': Example.fromCSV + }[format.lower()] lines = 0 with open(os.path.expanduser(path), encoding="utf8") as f: @@ -160,9 +161,13 @@ def __init__(self, reader = f next(reader) - examples = [make_example(line, fields) for line in - pyprind.prog_bar(reader, iterations=lines, - title='\nReading and processing data from "' + path + '"')] + examples = [ + make_example(line, fields) for line in pyprind.prog_bar( + reader, + iterations=lines, + title='\nReading and processing data from "' + path + + '"') + ] super(MatchingDataset, self).__init__(examples, fields, **kwargs) else: @@ -182,12 +187,14 @@ def _set_attributes(self): self.all_left_fields = [] for name, field in six.iteritems(self.fields): - if name.startswith(self.column_naming['left']) and field is not None: + if name.startswith( + self.column_naming['left']) and field is not None: self.all_left_fields.append(name) self.all_right_fields = [] for name, field in six.iteritems(self.fields): - if name.startswith(self.column_naming['right']) and field is not None: + if name.startswith( + self.column_naming['right']) and field is not None: self.all_right_fields.append(name) self.canonical_text_fields = [] @@ -224,13 +231,18 @@ def compute_metadata(self, pca=False): self.metadata = {} # Create an iterator over the entire dataset. - train_iter = MatchingIterator( - self, self, train=False, batch_size=1024, device='cpu', sort_in_buckets=False) + train_iter = MatchingIterator(self, + self, + train=False, + batch_size=1024, + device='cpu', + sort_in_buckets=False) counter = defaultdict(Counter) # For each attribute, find the number of times each word id occurs in the dataset. # Note that word ids here also include ``UNK`` tokens, padding tokens, etc. - for batch in pyprind.prog_bar(train_iter, title='\nBuilding vocabulary'): + for batch in pyprind.prog_bar(train_iter, + title='\nBuilding vocabulary'): for name in self.all_text_fields: attr_input = getattr(batch, name) counter[name].update(attr_input.data.data.view(-1).tolist()) @@ -269,14 +281,18 @@ def compute_metadata(self, pca=False): embed[name] = field_embed[field] # Create an iterator over the entire dataset. - train_iter = MatchingIterator( - self, self, train=False, batch_size=1024, device='cpu', sort_in_buckets=False) + train_iter = MatchingIterator(self, + self, + train=False, + batch_size=1024, + device='cpu', + sort_in_buckets=False) attr_embeddings = defaultdict(list) # Run the constructed neural network to compute weighted sequence embeddings # for each attribute of each example in the dataset. - for batch in pyprind.prog_bar(train_iter, - title='\nComputing principal components'): + for batch in pyprind.prog_bar( + train_iter, title='\nComputing principal components'): for name in self.all_text_fields: attr_input = getattr(batch, name) embeddings = inv_freq_pool(embed[name](attr_input)) @@ -312,7 +328,8 @@ def get_raw_table(self): using the whitespace delimiter. """ rows = [] - columns = list(name for name, field in six.iteritems(self.fields) if field) + columns = list(name for name, field in six.iteritems(self.fields) + if field) for ex in self.examples: row = [] for attr in columns: @@ -331,10 +348,12 @@ def sort_key(self, ex): A key to use for sorting dataset examples for batching together examples with similar lengths to minimize padding.""" - return interleave_keys([len(getattr(ex, attr)) for attr in self.all_text_fields]) + return interleave_keys( + [len(getattr(ex, attr)) for attr in self.all_text_fields]) @staticmethod - def save_cache(datasets, fields, datafiles, cachefile, column_naming, state_args): + def save_cache(datasets, fields, datafiles, cachefile, column_naming, + state_args): r"""Save datasets and corresponding metadata to cache. This method also saves as many data loading arguments as possible to help ensure @@ -355,7 +374,9 @@ def save_cache(datasets, fields, datafiles, cachefile, column_naming, state_args """ examples = [dataset.examples for dataset in datasets] train_metadata = datasets[0].metadata - datafiles_modified = [os.path.getmtime(datafile) for datafile in datafiles] + datafiles_modified = [ + os.path.getmtime(datafile) for datafile in datafiles + ] vocabs = {} field_args = {} reverse_fields = {} @@ -423,7 +444,9 @@ def load_cache(fields, datafiles, cachefile, column_naming, state_args): if datafiles != cached_data['datafiles']: cache_stale_cause.add('Data file list has changed.') - datafiles_modified = [os.path.getmtime(datafile) for datafile in datafiles] + datafiles_modified = [ + os.path.getmtime(datafile) for datafile in datafiles + ] if datafiles_modified != cached_data['datafiles_modified']: cache_stale_cause.add('One or more data files have been modified.') @@ -431,10 +454,13 @@ def load_cache(fields, datafiles, cachefile, column_naming, state_args): cache_stale_cause.add('Fields have changed.') for name, field in six.iteritems(fields): - none_mismatch = (field is None) != (cached_data['field_args'][name] is None) + none_mismatch = (field is None) != (cached_data['field_args'][name] + is None) args_mismatch = False - if field is not None and cached_data['field_args'][name] is not None: - args_mismatch = field.preprocess_args() != cached_data['field_args'][name] + if field is not None and cached_data['field_args'][ + name] is not None: + args_mismatch = field.preprocess_args( + ) != cached_data['field_args'][name] if none_mismatch or args_mismatch: cache_stale_cause.add('Field arguments have changed.') if field is not None and not isinstance(field, MatchingField): @@ -444,8 +470,8 @@ def load_cache(fields, datafiles, cachefile, column_naming, state_args): cache_stale_cause.add('Other arguments have changed.') cache_stale_cause.update( - MatchingDataset.state_args_compatibility(state_args, - cached_data['state_args'])) + MatchingDataset.state_args_compatibility( + state_args, cached_data['state_args'])) return cached_data, cache_stale_cause @@ -535,36 +561,47 @@ def splits(cls, datasets = None if cache: - datafiles = list(f for f in (train, validation, test) if f is not None) - datafiles = [os.path.expanduser(os.path.join(path, d)) for d in datafiles] + datafiles = list(f for f in (train, validation, test) + if f is not None) + datafiles = [ + os.path.expanduser(os.path.join(path, d)) for d in datafiles + ] cachefile = os.path.expanduser(os.path.join(path, cache)) try: cached_data, cache_stale_cause = MatchingDataset.load_cache( - fields_dict, datafiles, cachefile, column_naming, state_args) + fields_dict, datafiles, cachefile, column_naming, + state_args) if check_cached_data and cache_stale_cause: if not auto_rebuild_cache: - raise MatchingDataset.CacheStaleException(cache_stale_cause) + raise MatchingDataset.CacheStaleException( + cache_stale_cause) else: - logger.warning('Rebuilding data cache because: %s', list(cache_stale_cause)) + logger.warning('Rebuilding data cache because: %s', + list(cache_stale_cause)) if not check_cached_data or not cache_stale_cause: - datasets = MatchingDataset.restore_data(fields, cached_data) + datasets = MatchingDataset.restore_data( + fields, cached_data) except IOError: pass if not datasets: begin = timer() - dataset_args = {'fields': fields, 'column_naming': column_naming, **kwargs} + dataset_args = { + 'fields': fields, + 'column_naming': column_naming, + **kwargs + } train_data = None if train is None else cls( path=os.path.join(path, train), **dataset_args) val_data = None if validation is None else cls( path=os.path.join(path, validation), **dataset_args) test_data = None if test is None else cls( path=os.path.join(path, test), **dataset_args) - datasets = tuple( - d for d in (train_data, val_data, test_data) if d is not None) + datasets = tuple(d for d in (train_data, val_data, test_data) + if d is not None) after_load = timer() logger.info('Data load took: {}s'.format(after_load - begin)) @@ -572,22 +609,27 @@ def splits(cls, fields_set = set(fields_dict.values()) for field in fields_set: if field is not None and field.use_vocab: - field.build_vocab( - *datasets, vectors=embeddings, cache=embeddings_cache) + field.build_vocab(*datasets, + vectors=embeddings, + cache=embeddings_cache) after_vocab = timer() - logger.info('Vocab construction time: {}s'.format(after_vocab - after_load)) + logger.info('Vocab construction time: {}s'.format(after_vocab - + after_load)) if train: datasets[0].compute_metadata(train_pca) after_metadata = timer() logger.info( - 'Metadata computation time: {}s'.format(after_metadata - after_vocab)) + 'Metadata computation time: {}s'.format(after_metadata - + after_vocab)) if cache: - MatchingDataset.save_cache(datasets, fields_dict, datafiles, cachefile, - column_naming, state_args) + MatchingDataset.save_cache(datasets, fields_dict, datafiles, + cachefile, column_naming, + state_args) after_cache = timer() - logger.info('Cache save time: {}s'.format(after_cache - after_vocab)) + logger.info('Cache save time: {}s'.format(after_cache - + after_vocab)) if train: datasets[0].finalize_metadata() @@ -616,7 +658,6 @@ def interleave_keys(keys): values for the key defined by this function. Useful for tasks with two text fields like machine translation or natural language inference. """ - def interleave(args): return ''.join([x for t in zip(*args) for x in t]) diff --git a/deepmatcher/data/process.py b/deepmatcher/data/process.py index 8591eaa..f0ae688 100644 --- a/deepmatcher/data/process.py +++ b/deepmatcher/data/process.py @@ -14,7 +14,8 @@ logger = logging.getLogger(__name__) -def _check_header(header, id_attr, left_prefix, right_prefix, label_attr, ignore_columns): +def _check_header(header, id_attr, left_prefix, right_prefix, label_attr, + ignore_columns): r"""Verify CSV file header. Checks that: @@ -30,10 +31,12 @@ def _check_header(header, id_attr, left_prefix, right_prefix, label_attr, ignore for attr in header: if attr not in (id_attr, label_attr) and attr not in ignore_columns: - if not attr.startswith(left_prefix) and not attr.startswith(right_prefix): - raise ValueError('Attribute ' + attr + ' is not a left or a right table ' - 'column, not a label or id and is not ignored. Not sure ' - 'what it is...') + if not attr.startswith(left_prefix) and not attr.startswith( + right_prefix): + raise ValueError( + 'Attribute ' + attr + ' is not a left or a right table ' + 'column, not a label or id and is not ignored. Not sure ' + 'what it is...') num_left = sum(attr.startswith(left_prefix) for attr in header) num_right = sum(attr.startswith(right_prefix) for attr in header) @@ -53,15 +56,15 @@ def _make_fields(header, id_attr, label_attr, ignore_columns, lower, tokenize, in the same order that the columns occur in the CSV file. """ - text_field = MatchingField( - lower=lower, - tokenize=tokenize, - init_token='<<<', - eos_token='>>>', - batch_first=True, - include_lengths=include_lengths) - numeric_field = MatchingField( - sequential=False, preprocessing=lambda x: int(x), use_vocab=False) + text_field = MatchingField(lower=lower, + tokenize=tokenize, + init_token='<<<', + eos_token='>>>', + batch_first=True, + include_lengths=include_lengths) + numeric_field = MatchingField(sequential=False, + preprocessing=int, + use_vocab=False) id_field = MatchingField(sequential=False, use_vocab=False, id=True) fields = [] @@ -190,13 +193,15 @@ def process(path, # TODO(Sid): check for all datasets to make sure the files exist and have the same schema a_dataset = train or validation or test - with io.open(os.path.expanduser(os.path.join(path, a_dataset)), encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(path, a_dataset)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) _maybe_download_nltk_data() - _check_header(header, id_attr, left_prefix, right_prefix, label_attr, ignore_columns) - fields = _make_fields(header, id_attr, label_attr, ignore_columns, lowercase, - tokenize, include_lengths) + _check_header(header, id_attr, left_prefix, right_prefix, label_attr, + ignore_columns) + fields = _make_fields(header, id_attr, label_attr, ignore_columns, + lowercase, tokenize, include_lengths) column_naming = { 'id': id_attr, @@ -205,19 +210,18 @@ def process(path, 'label': label_attr } - datasets = MatchingDataset.splits( - path, - train, - validation, - test, - fields, - embeddings, - embeddings_cache_path, - column_naming, - cache, - check_cached_data, - auto_rebuild_cache, - train_pca=pca) + datasets = MatchingDataset.splits(path, + train, + validation, + test, + fields, + embeddings, + embeddings_cache_path, + column_naming, + cache, + check_cached_data, + auto_rebuild_cache, + train_pca=pca) # Save additional information to train dataset. datasets[0].ignore_columns = ignore_columns @@ -232,7 +236,7 @@ def process_unlabeled(path, trained_model, ignore_columns=None): """Creates a dataset object for an unlabeled dataset. Args: - path (string): + path (str): The full path to the unlabeled data file (not just the directory). trained_model (:class:`~deepmatcher.MatchingModel`): The trained model. The model is aware of the configuration of the training @@ -242,7 +246,22 @@ def process_unlabeled(path, trained_model, ignore_columns=None): A list of columns to ignore in the unlabeled CSV file. """ with io.open(path, encoding="utf8") as f: - header = next(unicode_csv_reader(f)) + return process_unlabeled_stream(f, trained_model, ignore_columns) + +def process_unlabeled_stream(stream, trained_model, ignore_columns=None): + """Creates a dataset object for an unlabeled dataset. + + Args: + stream (io.Stream): + A stream open for reading. + trained_model (:class:`~deepmatcher.MatchingModel`): + The trained model. The model is aware of the configuration of the training + data on which it was trained, and so this method reuses the same + configuration for the unlabeled data. + ignore_columns (list): + A list of columns to ignore in the unlabeled CSV file. + """ + header = next(unicode_csv_reader(stream)) train_info = trained_model.meta if ignore_columns is None: @@ -251,18 +270,18 @@ def process_unlabeled(path, trained_model, ignore_columns=None): column_naming['label'] = None fields = _make_fields(header, column_naming['id'], column_naming['label'], - ignore_columns, train_info.lowercase, train_info.tokenize, - train_info.include_lengths) + ignore_columns, train_info.lowercase, + train_info.tokenize, train_info.include_lengths) begin = timer() dataset_args = {'fields': fields, 'column_naming': column_naming} - dataset = MatchingDataset(path=path, **dataset_args) + dataset = MatchingDataset(stream=stream, **dataset_args) # Make sure we have the same attributes. assert set(dataset.all_text_fields) == set(train_info.all_text_fields) after_load = timer() - logger.info('Data load time: {}s'.format(after_load - begin)) + logger.info('Data load time: {delta}s', delta=(after_load - begin)) reverse_fields_dict = dict((pair[1], pair[0]) for pair in fields) for field, name in reverse_fields_dict.items(): @@ -270,8 +289,9 @@ def process_unlabeled(path, trained_model, ignore_columns=None): # Copy over vocab from original train data. field.vocab = copy.deepcopy(train_info.vocabs[name]) # Then extend the vocab. - field.extend_vocab( - dataset, vectors=train_info.embeddings, cache=train_info.embeddings_cache) + field.extend_vocab(dataset, + vectors=train_info.embeddings, + cache=train_info.embeddings_cache) dataset.vocabs = { name: dataset.fields[name].vocab @@ -279,6 +299,6 @@ def process_unlabeled(path, trained_model, ignore_columns=None): } after_vocab = timer() - logger.info('Vocab update time: {}s'.format(after_vocab - after_load)) + logger.info('Vocab update time: {delta}s', delta=(after_vocab - after_load)) return dataset diff --git a/deepmatcher/models/core.py b/deepmatcher/models/core.py index c36e8a8..1b37bbe 100644 --- a/deepmatcher/models/core.py +++ b/deepmatcher/models/core.py @@ -80,7 +80,6 @@ class MatchingModel(nn.Module): if they are string arguments. If a module or :attr:`callable` input is specified for a component, this argument is ignored for that component. """ - def __init__(self, attr_summarizer='hybrid', attr_condense_factor='auto', @@ -293,16 +292,18 @@ def initialize(self, train_dataset, init_batch=None): self.attr_summarizers[name] = AttrSummarizer._create( summarizer, hidden_size=self.hidden_size) assert len( - set(self.attr_summarizers.keys()) ^ set(self.meta.canonical_text_fields) - ) == 0 + set(self.attr_summarizers.keys()) + ^ set(self.meta.canonical_text_fields)) == 0 else: self.attr_summarizer = AttrSummarizer._create( self.attr_summarizer, hidden_size=self.hidden_size) for name in self.meta.canonical_text_fields: - self.attr_summarizers[name] = copy.deepcopy(self.attr_summarizer) + self.attr_summarizers[name] = copy.deepcopy( + self.attr_summarizer) if self.attr_condense_factor == 'auto': - self.attr_condense_factor = min(len(self.meta.canonical_text_fields), 6) + self.attr_condense_factor = min( + len(self.meta.canonical_text_fields), 6) if self.attr_condense_factor == 1: self.attr_condense_factor = None @@ -319,10 +320,11 @@ def initialize(self, train_dataset, init_batch=None): self.attr_comparators = dm.modules.ModuleMap() if isinstance(self.attr_comparator, Mapping): for name, comparator in self.attr_comparator.items(): - self.attr_comparators[name] = _create_attr_comparator(comparator) + self.attr_comparators[name] = _create_attr_comparator( + comparator) assert len( - set(self.attr_comparators.keys()) ^ set(self.meta.canonical_text_fields) - ) == 0 + set(self.attr_comparators.keys()) + ^ set(self.meta.canonical_text_fields)) == 0 else: if isinstance(self.attr_summarizer, AttrSummarizer): self.attr_comparator = self._get_attr_comparator( @@ -332,25 +334,27 @@ def initialize(self, train_dataset, init_batch=None): raise ValueError('"attr_comparator" must be specified if ' '"attr_summarizer" is custom.') - self.attr_comparator = _create_attr_comparator(self.attr_comparator) + self.attr_comparator = _create_attr_comparator( + self.attr_comparator) for name in self.meta.canonical_text_fields: - self.attr_comparators[name] = copy.deepcopy(self.attr_comparator) + self.attr_comparators[name] = copy.deepcopy( + self.attr_comparator) self.attr_merge = dm.modules._merge_module(self.attr_merge) - self.classifier = _utils.get_module( - Classifier, self.classifier, hidden_size=self.hidden_size) + self.classifier = _utils.get_module(Classifier, + self.classifier, + hidden_size=self.hidden_size) self._reset_embeddings(train_dataset.vocabs) # Instantiate all components using a small batch from training set. if not init_batch: - run_iter = MatchingIterator( - train_dataset, - train_dataset, - train=False, - batch_size=4, - device='cpu', - sort_in_buckets=False) + run_iter = MatchingIterator(train_dataset, + train_dataset, + train=False, + batch_size=4, + device='cpu', + sort_in_buckets=False) init_batch = next(run_iter.__iter__()) self.forward(init_batch) @@ -358,8 +362,10 @@ def initialize(self, train_dataset, init_batch=None): self.state_meta.init_batch = init_batch self._initialized = True - logger.info('Successfully initialized MatchingModel with {:d} trainable ' - 'parameters.'.format(tally_parameters(self))) + logger.info( + 'Successfully initialized MatchingModel with {nparam:d} trainable ' + 'parameters.', + nparam=tally_parameters(self)) def _reset_embeddings(self, vocabs): self.embed = dm.modules.ModuleMap() @@ -421,8 +427,8 @@ def forward(self, input): attr_comparisons = [] for name in self.meta.canonical_text_fields: left, right = self.meta.text_fields[name] - left_summary, right_summary = self.attr_summarizers[name](embeddings[left], - embeddings[right]) + left_summary, right_summary = self.attr_summarizers[name]( + embeddings[left], embeddings[right]) # Remove metadata information at this point. left_summary, right_summary = left_summary.data, right_summary.data @@ -512,7 +518,6 @@ class AttrSummarizer(dm.modules.LazyModule): size of the last dimension of the input to this module as the hidden size. Defaults to None. """ - def _init(self, word_contextualizer, word_comparator, @@ -520,10 +525,10 @@ def _init(self, hidden_size=None): self.word_contextualizer = WordContextualizer._create( word_contextualizer, hidden_size=hidden_size) - self.word_comparator = WordComparator._create( - word_comparator, hidden_size=hidden_size) - self.word_aggregator = WordAggregator._create( - word_aggregator, hidden_size=hidden_size) + self.word_comparator = WordComparator._create(word_comparator, + hidden_size=hidden_size) + self.word_aggregator = WordAggregator._create(word_aggregator, + hidden_size=hidden_size) def _forward(self, left_input, right_input): r"""The forward function of attribute summarizer. @@ -535,16 +540,20 @@ def _forward(self, left_input, right_input): left_compared, right_compared = left_contextualized, right_contextualized if self.word_comparator: - left_compared = self.word_comparator( - left_contextualized, right_contextualized, left_input, right_input) - right_compared = self.word_comparator( - right_contextualized, left_contextualized, right_input, left_input) + left_compared = self.word_comparator(left_contextualized, + right_contextualized, + left_input, right_input) + right_compared = self.word_comparator(right_contextualized, + left_contextualized, + right_input, left_input) left_aggregator_context = right_input right_aggregator_context = left_input - left_aggregated = self.word_aggregator(left_compared, left_aggregator_context) - right_aggregated = self.word_aggregator(right_compared, right_aggregator_context) + left_aggregated = self.word_aggregator(left_compared, + left_aggregator_context) + right_aggregated = self.word_aggregator(right_compared, + right_aggregator_context) return left_aggregated, right_aggregated @classmethod @@ -603,7 +612,6 @@ class WordContextualizer(dm.modules.LazyModule): for details. Sub-classes that implement various options for this module are defined in :mod:`deepmatcher.word_contextualizers`. """ - @classmethod def _create(cls, arg, **kwargs): r"""Create a word contextualizer object. @@ -644,7 +652,6 @@ class WordComparator(dm.modules.LazyModule): for details. Sub-classes that implement various options for this module are defined in :mod:`deepmatcher.word_comparators`. """ - @classmethod def _create(cls, arg, **kwargs): r"""Create a word comparator object. @@ -660,9 +667,10 @@ def _create(cls, arg, **kwargs): """ if isinstance(arg, six.string_types): parts = arg.split('-') - if (parts[1] == 'attention' and - dm.modules.AlignmentNetwork.supports_style(parts[0])): - wc = dm.word_comparators.Attention(alignment_network=parts[0], **kwargs) + if (parts[1] == 'attention' + and dm.modules.AlignmentNetwork.supports_style(parts[0])): + wc = dm.word_comparators.Attention(alignment_network=parts[0], + **kwargs) else: raise ValueError('Unknown Word Comparator name.') else: @@ -684,7 +692,6 @@ class WordAggregator(dm.modules.LazyModule): for details. Sub-classes that implement various options for this module are defined in :mod:`deepmatcher.word_aggregators`. """ - @classmethod def _create(cls, arg, **kwargs): r""" @@ -702,14 +709,18 @@ def _create(cls, arg, **kwargs): assert arg is not None if isinstance(arg, six.string_types): parts = arg.split('-') - if (parts[-1] == 'pool' and - dm.word_aggregators.Pool.supports_style('-'.join(parts[:-1]))): + if (parts[-1] == 'pool' + and dm.word_aggregators.Pool.supports_style('-'.join( + parts[:-1]))): seq = [] - seq.append(dm.modules.Lambda(lambda x1, x2: x1)) # Ignore the context. - seq.append(dm.word_aggregators.Pool(style='-'.join(parts[:-1]))) + seq.append(dm.modules.Lambda( + lambda x1, x2: x1)) # Ignore the context. + seq.append( + dm.word_aggregators.Pool(style='-'.join(parts[:-1]))) # Make lazy module. - wa = dm.modules.LazyModuleFn(lambda: dm.modules.MultiSequential(*seq)) + wa = dm.modules.LazyModuleFn( + lambda: dm.modules.MultiSequential(*seq)) elif arg == 'attention-with-rnn': wa = dm.word_aggregators.AttentionWithRNN(**kwargs) else: @@ -739,13 +750,13 @@ class Classifier(nn.Sequential): The size of the hidden representation generated by the transformation network. If None, uses the size of the input vector to this module as the hidden size. """ - def __init__(self, transform_network, hidden_size=None): super(Classifier, self).__init__() if transform_network: - self.add_module('transform', - dm.modules._transform_module(transform_network, hidden_size)) - self.add_module('softmax_transform', - dm.modules.Transform( - '1-layer', non_linearity=None, output_size=2)) + self.add_module( + 'transform', + dm.modules._transform_module(transform_network, hidden_size)) + self.add_module( + 'softmax_transform', + dm.modules.Transform('1-layer', non_linearity=None, output_size=2)) self.add_module('softmax', nn.LogSoftmax(dim=1)) diff --git a/deepmatcher/models/modules.py b/deepmatcher/models/modules.py index 6bf1433..5ffea39 100644 --- a/deepmatcher/models/modules.py +++ b/deepmatcher/models/modules.py @@ -56,7 +56,6 @@ class LazyModule(nn.Module): defined. Whatever you typically define in the forward function of a PyTorch module, you may define it here. All subclasses must override this method. """ - def __init__(self, *args, **kwargs): """Construct a :class:`LazyModule`. DO NOT OVERRIDE this method. @@ -100,12 +99,12 @@ def forward(self, input, *args, **kwargs): """ if not self._initialized: try: - self._init( - *self._init_args, - input_size=self._get_input_size(input, *args, **kwargs), - **self._init_kwargs) + self._init(*self._init_args, + input_size=self._get_input_size( + input, *args, **kwargs), + **self._init_kwargs) except TypeError as e: - logger.debug('Got exception when passing input size: ' + str(e)) + logger.debug('Got exception when passing input size: %s', e) self._init(*self._init_args, **self._init_kwargs) for fn in self._fns: super(LazyModule, self)._apply(fn) @@ -174,7 +173,6 @@ class NoMeta(nn.Module): Args: module (:class:`~torch.nn.Module`): The module to wrap. """ - def __init__(self, module): super(NoMeta, self).__init__() self.module = module @@ -182,7 +180,8 @@ def __init__(self, module): def forward(self, *args): module_args = [] for arg in args: - module_args.append(arg.data if isinstance(arg, AttrTensor) else arg) + module_args.append( + arg.data if isinstance(arg, AttrTensor) else arg) results = self.module(*module_args) @@ -190,12 +189,14 @@ def forward(self, *args): return results else: if not isinstance(results, tuple): - results = (results,) + results = (results, ) - if len(results) != len(args) and len(results) != 1 and len(args) != 1: + if len(results) != len(args) and len(results) != 1 and len( + args) != 1: raise ValueError( 'Number of inputs must equal number of outputs, or ' - 'number of inputs must be 1 or number of outputs must be 1.') + 'number of inputs must be 1 or number of outputs must be 1.' + ) results_with_meta = [] for i in range(len(results)): @@ -232,7 +233,6 @@ def forward(self, x1, x2): return y1, y2 """ - def __getitem__(self, name): return getattr(self, name) @@ -249,12 +249,12 @@ class MultiSequential(nn.Sequential): This is an extenstion of PyTorch's :class:`~torch.nn.Sequential` module that allows each module to have multiple inputs and / or outputs. """ - def forward(self, *inputs): modules = list(self._modules.values()) inputs = modules[0](*inputs) for module in modules[1:]: - if isinstance(inputs, tuple) and not isinstance(inputs, AttrTensor): + if isinstance(inputs, + tuple) and not isinstance(inputs, AttrTensor): inputs = module(*inputs) else: inputs = module(inputs) @@ -278,7 +278,6 @@ class LazyModuleFn(LazyModule): *kwargs: Keyword arguments to the function `fn`. """ - def _init(self, fn, *args, **kwargs): self.module = fn(*args, **kwargs) @@ -396,15 +395,14 @@ def _init(self, rnn_in_size = input_size for g in range(rnn_groups): self.rnn_groups.append( - self._get_rnn_module( - unit_type, - input_size=rnn_in_size, - hidden_size=hidden_size, - num_layers=layers_per_group, - batch_first=True, - dropout=dropout, - bidirectional=bidirectional, - **kwargs)) + self._get_rnn_module(unit_type, + input_size=rnn_in_size, + hidden_size=hidden_size, + num_layers=layers_per_group, + batch_first=True, + dropout=dropout, + bidirectional=bidirectional, + **kwargs)) if g != rnn_groups: self.dropouts.append(nn.Dropout(dropout)) @@ -538,11 +536,13 @@ def _forward(self, input, context): elif self.style == 'general': return torch.bmm( input, # batch x len1 x input_size - self.transform(context).transpose(1, 2)) # batch x input_size x len2 + self.transform(context).transpose( + 1, 2)) # batch x input_size x len2 elif self.style == 'decomposable': return torch.bmm( self.transform(input), # batch x hidden_size x len2 - self.transform(context).transpose(1, 2)) # batch x hidden_size x len2 + self.transform(context).transpose( + 1, 2)) # batch x hidden_size x len2 # elif self.style in ['concat', 'concat_dot']: # # batch x len1 x 1 x output_size # input_transformed = self.input_transform(input).unsqueeze(2) @@ -569,7 +569,6 @@ class Lambda(nn.Module): more Pytorch :class:`~torch.Tensor` s and return one or more :class:`~torch.Tensor` s. """ - def __init__(self, lambd): super(Lambda, self).__init__() self.lambd = lambd @@ -699,7 +698,8 @@ def _forward(self, input_with_meta): if self.style == 'last': lengths = input_with_meta.lengths - lasts = Variable(lengths.view(-1, 1, 1).repeat(1, 1, input.size(2))) - 1 + lasts = Variable( + lengths.view(-1, 1, 1).repeat(1, 1, input.size(2))) - 1 output = torch.gather(input, 1, lasts).squeeze(1).float() elif self.style == 'last-simple': output = input[:, input.size(1), :] @@ -729,21 +729,26 @@ def _forward(self, input_with_meta): mask = mask.unsqueeze(2) # Make it broadcastable. input.data.masked_fill_(~mask, 0) - lengths = Variable(input_with_meta.lengths.clamp(min=1).unsqueeze(1).float()) + lengths = Variable( + input_with_meta.lengths.clamp(min=1).unsqueeze(1).float()) if self.style == 'avg': output = input.sum(1) / lengths elif self.style == 'divsqrt': output = input.sum(1) / lengths.sqrt() elif self.style == 'inv-freq-avg': - inv_probs = self.alpha / (input_with_meta.word_probs + self.alpha) + inv_probs = self.alpha / (input_with_meta.word_probs + + self.alpha) weighted = input * Variable(inv_probs.unsqueeze(2)) output = weighted.sum(1) / lengths.sqrt() elif self.style == 'sif': - inv_probs = self.alpha / (input_with_meta.word_probs + self.alpha) + inv_probs = self.alpha / (input_with_meta.word_probs + + self.alpha) weighted = input * Variable(inv_probs.unsqueeze(2)) v = (weighted.sum(1) / lengths.sqrt()) - pc = Variable(input_with_meta.pc).unsqueeze(0).repeat(v.shape[0], 1) - proj_v_on_pc = torch.bmm(v.unsqueeze(1), pc.unsqueeze(2)).squeeze(2) * pc + pc = Variable(input_with_meta.pc).unsqueeze(0).repeat( + v.shape[0], 1) + proj_v_on_pc = torch.bmm(v.unsqueeze(1), + pc.unsqueeze(2)).squeeze(2) * pc output = v - proj_v_on_pc else: raise NotImplementedError(self.style + ' is not implemented.') @@ -780,12 +785,20 @@ class Merge(LazyModule): """ _style_map = { - 'concat': lambda *args: torch.cat(args, args[0].dim() - 1), - 'diff': lambda x, y: x - y, - 'abs-diff': lambda x, y: torch.abs(x - y), - 'concat-diff': lambda x, y: torch.cat((x, y, x - y), x.dim() - 1), - 'concat-abs-diff': lambda x, y: torch.cat((x, y, torch.abs(x - y)), x.dim() - 1), - 'mul': lambda x, y: torch.mul(x, y) + 'concat': + lambda *args: torch.cat(args, args[0].dim() - 1), + 'diff': + lambda x, y: x - y, + 'abs-diff': + lambda x, y: torch.abs(x - y), + 'concat-diff': + lambda x, y: torch.cat((x, y, x - y), + x.dim() - 1), + 'concat-abs-diff': + lambda x, y: torch.cat((x, y, torch.abs(x - y)), + x.dim() - 1), + 'mul': + lambda x, y: torch.mul(x, y) } @classmethod @@ -828,7 +841,11 @@ class Bypass(LazyModule): def supports_style(cls, style): return style.lower() in cls._supported_styles - def _init(self, style, residual_scale=True, highway_bias=-2, input_size=None): + def _init(self, + style, + residual_scale=True, + highway_bias=-2, + input_size=None): assert self.supports_style(style) self.style = style.lower() self.residual_scale = residual_scale @@ -847,12 +864,12 @@ def _forward(self, transformed, raw): padded = F.pad(raw, (0, tsize - rsize % tsize)) else: padded = raw - adjusted_raw = padded.view(*raw.shape[:-1], -1, tsize).sum(-2) * math.sqrt( - tsize / rsize) + adjusted_raw = padded.view( + *raw.shape[:-1], -1, tsize).sum(-2) * math.sqrt(tsize / rsize) elif tsize > rsize: multiples = math.ceil(tsize / rsize) - adjusted_raw = raw.repeat(*([1] * (raw.dim() - 1)), multiples).narrow( - -1, 0, tsize) + adjusted_raw = raw.repeat(*([1] * (raw.dim() - 1)), + multiples).narrow(-1, 0, tsize) if self.style == 'residual': res = transformed + adjusted_raw @@ -860,7 +877,8 @@ def _forward(self, transformed, raw): res *= math.sqrt(0.5) return res elif self.style == 'highway': - transform_gate = torch.sigmoid(self.highway_gate(raw) + self.highway_bias) + transform_gate = torch.sigmoid( + self.highway_gate(raw) + self.highway_bias) carry_gate = 1 - transform_gate return transform_gate * transformed + carry_gate * adjusted_raw @@ -971,7 +989,8 @@ def _init(self, self.transforms = nn.ModuleList() self.bypass_networks = nn.ModuleList() - assert (non_linearity is None or self.supports_nonlinearity(non_linearity)) + assert (non_linearity is None + or self.supports_nonlinearity(non_linearity)) self.non_linearity = non_linearity.lower() if non_linearity else None transform_in_size = input_size @@ -979,7 +998,8 @@ def _init(self, for layer in range(layers): if layer == layers - 1: transform_out_size = output_size - self.transforms.append(nn.Linear(transform_in_size, transform_out_size)) + self.transforms.append( + nn.Linear(transform_in_size, transform_out_size)) self.bypass_networks.append(_bypass_module(bypass_network)) transform_in_size = transform_out_size @@ -1013,8 +1033,10 @@ def _bypass_module(op): def _transform_module(op, hidden_size, output_size=None): output_size = output_size or hidden_size - module = _utils.get_module( - Transform, op, hidden_size=hidden_size, output_size=output_size) + module = _utils.get_module(Transform, + op, + hidden_size=hidden_size, + output_size=output_size) if module: module.expect_signature('[AxB] -> [AxC]') module.expect_signature('[AxBxC] -> [AxBxD]') @@ -1022,7 +1044,9 @@ def _transform_module(op, hidden_size, output_size=None): def _alignment_module(op, hidden_size): - module = _utils.get_module( - AlignmentNetwork, op, hidden_size=hidden_size, required=True) + module = _utils.get_module(AlignmentNetwork, + op, + hidden_size=hidden_size, + required=True) module.expect_signature('[AxBxC, AxDxC] -> [AxBxD]') return module diff --git a/deepmatcher/optim.py b/deepmatcher/optim.py index b0833a6..807c1ae 100644 --- a/deepmatcher/optim.py +++ b/deepmatcher/optim.py @@ -33,8 +33,11 @@ class imbalance in the training set. well as** over dimensions. However, if ``False`` the losses are instead summed. This is a keyword only parameter. """ - - def __init__(self, label_smoothing=0, weight=None, num_classes=2, **kwargs): + def __init__(self, + label_smoothing=0, + weight=None, + num_classes=2, + **kwargs): super(SoftNLLLoss, self).__init__(**kwargs) self.label_smoothing = label_smoothing self.confidence = 1 - self.label_smoothing @@ -79,7 +82,6 @@ class Optimizer(object): adagrad_accum (float, optional): Initialization hyperparameter for adagrad. """ - def __init__(self, method='adam', lr=0.001, @@ -126,14 +128,16 @@ def set_parameters(self, params): elif self.method == 'adadelta': self.base_optimizer = optim.Adadelta(self.params, lr=self.lr) elif self.method == 'adam': - self.base_optimizer = optim.Adam( - self.params, lr=self.lr, betas=self.betas, eps=1e-9) + self.base_optimizer = optim.Adam(self.params, + lr=self.lr, + betas=self.betas, + eps=1e-9) else: raise RuntimeError("Invalid optim method: " + self.method) def _set_rate(self, lr): for param_group in self.base_optimizer.param_groups: - param_group['lr'] = self.lr + param_group['lr'] = lr def step(self): """Update the model parameters based on current gradients. diff --git a/deepmatcher/runner.py b/deepmatcher/runner.py index f2eb7e7..962655e 100644 --- a/deepmatcher/runner.py +++ b/deepmatcher/runner.py @@ -33,7 +33,6 @@ class Statistics(object): * Recall * Accuracy """ - def __init__(self): self.loss_sum = 0 self.examples = 0 @@ -78,42 +77,42 @@ class Runner(object): This class implements routines to train, evaluate and make predictions from models. """ - @staticmethod def _print_stats(name, epoch, batch, n_batches, stats, cum_stats): """Write out batch statistics to stdout. """ - print((' | {name} | [{epoch}][{batch:4d}/{n_batches}] || Loss: {loss:7.4f} |' - ' F1: {f1:7.2f} | Prec: {prec:7.2f} | Rec: {rec:7.2f} ||' - ' Cum. F1: {cf1:7.2f} | Cum. Prec: {cprec:7.2f} | Cum. Rec: {crec:7.2f} ||' - ' Ex/s: {eps:6.1f}').format( - name=name, - epoch=epoch, - batch=batch, - n_batches=n_batches, - loss=stats.loss(), - f1=stats.f1(), - prec=stats.precision(), - rec=stats.recall(), - cf1=cum_stats.f1(), - cprec=cum_stats.precision(), - crec=cum_stats.recall(), - eps=cum_stats.examples_per_sec())) + print(( + ' | {name} | [{epoch}][{batch:4d}/{n_batches}] || Loss: {loss:7.4f} |' + ' F1: {f1:7.2f} | Prec: {prec:7.2f} | Rec: {rec:7.2f} ||' + ' Cum. F1: {cf1:7.2f} | Cum. Prec: {cprec:7.2f} | Cum. Rec: {crec:7.2f} ||' + ' Ex/s: {eps:6.1f}').format(name=name, + epoch=epoch, + batch=batch, + n_batches=n_batches, + loss=stats.loss(), + f1=stats.f1(), + prec=stats.precision(), + rec=stats.recall(), + cf1=cum_stats.f1(), + cprec=cum_stats.precision(), + crec=cum_stats.recall(), + eps=cum_stats.examples_per_sec())) @staticmethod def _print_final_stats(epoch, runtime, datatime, stats): """Write out epoch statistics to stdout. """ - print(('Finished Epoch {epoch} || Run Time: {runtime:6.1f} | ' - 'Load Time: {datatime:6.1f} || F1: {f1:6.2f} | Prec: {prec:6.2f} | ' - 'Rec: {rec:6.2f} || Ex/s: {eps:6.2f}\n').format( - epoch=epoch, - runtime=runtime, - datatime=datatime, - f1=stats.f1(), - prec=stats.precision(), - rec=stats.recall(), - eps=stats.examples_per_sec())) + print(( + 'Finished Epoch {epoch} || Run Time: {runtime:6.1f} | ' + 'Load Time: {datatime:6.1f} || F1: {f1:6.2f} | Prec: {prec:6.2f} | ' + 'Rec: {rec:6.2f} || Ex/s: {eps:6.2f}\n').format( + epoch=epoch, + runtime=runtime, + datatime=datatime, + f1=stats.f1(), + prec=stats.precision(), + rec=stats.recall(), + eps=stats.examples_per_sec())) @staticmethod def _set_pbar_status(pbar, stats, cum_stats): @@ -163,13 +162,12 @@ def _run(run_type, device = 'cuda' sort_in_buckets = train - run_iter = MatchingIterator( - dataset, - model.meta, - train, - batch_size=batch_size, - device=device, - sort_in_buckets=sort_in_buckets) + run_iter = MatchingIterator(dataset, + model.meta, + train, + batch_size=batch_size, + device=device, + sort_in_buckets=sort_in_buckets) model = model.to(device) if criterion: @@ -198,14 +196,15 @@ def _run(run_type, # The tqdm-bar for Jupyter notebook is under development. if progress_style == 'tqdm-bar': - pbar = tqdm( - total=len(run_iter) // log_freq, - bar_format='{l_bar}{bar}{postfix}', - file=sys.stdout) + pbar = tqdm(total=len(run_iter) // log_freq, + bar_format='{l_bar}{bar}{postfix}', + file=sys.stdout) # Use the pyprind bar as the default progress bar. if progress_style == 'bar': - pbar = pyprind.ProgBar(len(run_iter) // log_freq, bar_char='█', width=30) + pbar = pyprind.ProgBar(len(run_iter) // log_freq, + bar_char='█', + width=30) for batch_idx, batch in enumerate(run_iter): batch_start = time.time() @@ -222,7 +221,8 @@ def _run(run_type, loss = criterion(output, getattr(batch, label_attr)) if hasattr(batch, label_attr): - scores = Runner._compute_scores(output, getattr(batch, label_attr)) + scores = Runner._compute_scores(output, + getattr(batch, label_attr)) else: scores = [0] * 4 @@ -235,8 +235,8 @@ def _run(run_type, if (batch_idx + 1) % log_freq == 0: if progress_style == 'log': - Runner._print_stats(run_type, epoch + 1, batch_idx + 1, len(run_iter), - stats, cum_stats) + Runner._print_stats(run_type, epoch + 1, batch_idx + 1, + len(run_iter), stats, cum_stats) elif progress_style == 'tqdm-bar': pbar.update() Runner._set_pbar_status(pbar, stats, cum_stats) @@ -304,9 +304,10 @@ def train(model, if criterion is None: if pos_weight is not None: assert pos_weight < 2 - warnings.warn('"pos_weight" parameter is deprecated and will be removed ' - 'in a later release, please use "pos_neg_ratio" instead', - DeprecationWarning) + warnings.warn( + '"pos_weight" parameter is deprecated and will be removed ' + 'in a later release, please use "pos_neg_ratio" instead', + DeprecationWarning) assert pos_neg_ratio is None else: if pos_neg_ratio is None: @@ -322,7 +323,8 @@ def train(model, optimizer = optimizer or Optimizer() if model.optimizer_state is not None: - model.optimizer.base_optimizer.load_state_dict(model.optimizer_state) + model.optimizer.base_optimizer.load_state_dict( + model.optimizer_state) if model.epoch is None: epochs_range = range(epochs) @@ -335,10 +337,19 @@ def train(model, for epoch in epochs_range: model.epoch = epoch - Runner._run( - 'TRAIN', model, train_dataset, criterion, optimizer, train=True, **kwargs) - - score = Runner._run('EVAL', model, validation_dataset, train=False, **kwargs) + Runner._run('TRAIN', + model, + train_dataset, + criterion, + optimizer, + train=True, + **kwargs) + + score = Runner._run('EVAL', + model, + validation_dataset, + train=False, + **kwargs) optimizer.update_learning_rate(score, epoch + 1) model.optimizer_state = optimizer.base_optimizer.state_dict() @@ -354,7 +365,8 @@ def train(model, model.save_state(best_save_path) print('Done.') - if save_every_prefix is not None and (epoch + 1) % save_every_freq == 0: + if save_every_prefix is not None and (epoch + + 1) % save_every_freq == 0: print('Saving epoch model...') save_path = '{prefix}_ep{epoch}.pth'.format( prefix=save_every_prefix, epoch=epoch + 1) @@ -399,9 +411,13 @@ def predict(model, dataset, output_attributes=False, **kwargs): model = copy.deepcopy(model) model._reset_embeddings(dataset.vocabs) - predictions = Runner._run( - 'PREDICT', model, dataset, return_predictions=True, **kwargs) - pred_table = pd.DataFrame(predictions, columns=(dataset.id_field, 'match_score')) + predictions = Runner._run('PREDICT', + model, + dataset, + return_predictions=True, + **kwargs) + pred_table = pd.DataFrame(predictions, + columns=(dataset.id_field, 'match_score')) pred_table = pred_table.set_index(dataset.id_field) if output_attributes: diff --git a/deepmatcher/utils.py b/deepmatcher/utils.py index 24aa947..6094070 100644 --- a/deepmatcher/utils.py +++ b/deepmatcher/utils.py @@ -1,5 +1,4 @@ class Bunch: - def __init__(self, **kwds): self.__dict__.update(kwds) diff --git a/setup.cfg b/setup.cfg index 9886c7c..cfa0c9c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,3 +4,7 @@ # need to generate separate wheels for each Python version that you # support. universal=1 + + +[yapf] +based_on_style = pep8 \ No newline at end of file diff --git a/test/test_dataset.py b/test/test_dataset.py index 428311b..81921c2 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -15,18 +15,22 @@ from urllib.parse import urljoin from urllib.request import pathname2url - # import nltk # nltk.download('perluniprops') # nltk.download('nonbreaking_prefixes') class ClassMatchingDatasetTestCases(unittest.TestCase): - def test_init_1(self): fields = [('left_a', MatchingField()), ('right_a', MatchingField())] - col_naming = {'id': 'id', 'label': 'label', 'left': 'left', 'right': 'right'} - path = os.path.join(test_dir_path, 'test_datasets', 'sample_table_small.csv') + col_naming = { + 'id': 'id', + 'label': 'label', + 'left': 'left', + 'right': 'right' + } + path = os.path.join(test_dir_path, 'test_datasets', + 'sample_table_small.csv') md = MatchingDataset(fields, col_naming, path=path) self.assertEqual(md.id_field, 'id') self.assertEqual(md.label_field, 'label') @@ -37,23 +41,22 @@ def test_init_1(self): class MatchingDatasetSplitsTestCases(unittest.TestCase): - def setUp(self): self.data_dir = os.path.join(test_dir_path, 'test_datasets') self.train = 'test_train.csv' self.validation = 'test_valid.csv' self.test = 'test_test.csv' self.cache_name = 'test_cacheddata.pth' - with io.open( - os.path.expanduser(os.path.join(self.data_dir, self.train)), - encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(self.data_dir, + self.train)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) id_attr = 'id' label_attr = 'label' ignore_columns = ['left_id', 'right_id'] - self.fields = _make_fields(header, id_attr, label_attr, ignore_columns, True, - 'nltk', False) + self.fields = _make_fields(header, id_attr, label_attr, ignore_columns, + True, 'nltk', False) self.column_naming = { 'id': id_attr, @@ -68,76 +71,70 @@ def tearDown(self): os.remove(cache_name) def test_splits_1(self): - datasets = MatchingDataset.splits( - self.data_dir, - self.train, - self.validation, - self.test, - self.fields, - None, - None, - self.column_naming, - self.cache_name, - train_pca=False) + datasets = MatchingDataset.splits(self.data_dir, + self.train, + self.validation, + self.test, + self.fields, + None, + None, + self.column_naming, + self.cache_name, + train_pca=False) @raises(MatchingDataset.CacheStaleException) def test_splits_2(self): - datasets = MatchingDataset.splits( - self.data_dir, - self.train, - self.validation, - self.test, - self.fields, - None, - None, - self.column_naming, - self.cache_name, - train_pca=False) - - datasets_2 = MatchingDataset.splits( - self.data_dir, - 'sample_table_small.csv', - self.validation, - self.test, - self.fields, - None, - None, - self.column_naming, - self.cache_name, - True, - False, - train_pca=False) + datasets = MatchingDataset.splits(self.data_dir, + self.train, + self.validation, + self.test, + self.fields, + None, + None, + self.column_naming, + self.cache_name, + train_pca=False) + + datasets_2 = MatchingDataset.splits(self.data_dir, + 'sample_table_small.csv', + self.validation, + self.test, + self.fields, + None, + None, + self.column_naming, + self.cache_name, + True, + False, + train_pca=False) def test_splits_3(self): - datasets = MatchingDataset.splits( - self.data_dir, - self.train, - self.validation, - self.test, - self.fields, - None, - None, - self.column_naming, - self.cache_name, - train_pca=False) - - datasets_2 = MatchingDataset.splits( - self.data_dir, - self.train, - self.validation, - self.test, - self.fields, - None, - None, - self.column_naming, - self.cache_name, - False, - False, - train_pca=False) + datasets = MatchingDataset.splits(self.data_dir, + self.train, + self.validation, + self.test, + self.fields, + None, + None, + self.column_naming, + self.cache_name, + train_pca=False) + + datasets_2 = MatchingDataset.splits(self.data_dir, + self.train, + self.validation, + self.test, + self.fields, + None, + None, + self.column_naming, + self.cache_name, + False, + False, + train_pca=False) class DataframeSplitTestCases(unittest.TestCase): - def test_split_1(self): labeled_path = os.path.join(test_dir_path, 'test_datasets', 'sample_table_large.csv') @@ -200,13 +197,13 @@ def test_split_2(self): class GetRawTableTestCases(unittest.TestCase): - def test_get_raw_table(self): vectors_cache_dir = '.cache' if os.path.exists(vectors_cache_dir): shutil.rmtree(vectors_cache_dir) - data_cache_path = os.path.join(test_dir_path, 'test_datasets', 'cacheddata.pth') + data_cache_path = os.path.join(test_dir_path, 'test_datasets', + 'cacheddata.pth') if os.path.exists(data_cache_path): os.remove(data_cache_path) @@ -215,17 +212,17 @@ def test_get_raw_table(self): url_base = urljoin('file:', pathname2url(vec_dir)) + os.path.sep ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir) - train = process( - path=os.path.join(test_dir_path, 'test_datasets'), - train='sample_table_small.csv', - id_attr='id', - embeddings=ft, - embeddings_cache_path='', - pca=False) + train = process(path=os.path.join(test_dir_path, 'test_datasets'), + train='sample_table_small.csv', + id_attr='id', + embeddings=ft, + embeddings_cache_path='', + pca=False) train_raw = train.get_raw_table() ori_train = pd.read_csv( - os.path.join(test_dir_path, 'test_datasets', 'sample_table_small.csv')) + os.path.join(test_dir_path, 'test_datasets', + 'sample_table_small.csv')) self.assertEqual(set(train_raw.columns), set(ori_train.columns)) if os.path.exists(data_cache_path): diff --git a/test/test_field.py b/test/test_field.py index a86803b..621a69e 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -20,7 +20,6 @@ class ClassFastTextTestCases(unittest.TestCase): - def test_init_1(self): vectors_cache_dir = '.cache' if os.path.exists(vectors_cache_dir): @@ -38,7 +37,6 @@ def test_init_1(self): class ClassFastTextBinaryTestCases(unittest.TestCase): - @raises(RuntimeError) def test_init_1(self): vectors_cache_dir = '.cache' @@ -47,8 +45,11 @@ def test_init_1(self): pathdir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets')) filename = 'fasttext_sample.vec.zip' - url_base = urljoin('file:', pathname2url(os.path.join(pathdir, filename))) - mftb = FastTextBinary(filename, url_base=url_base, cache=vectors_cache_dir) + url_base = urljoin('file:', + pathname2url(os.path.join(pathdir, filename))) + mftb = FastTextBinary(filename, + url_base=url_base, + cache=vectors_cache_dir) if os.path.exists(vectors_cache_dir): shutil.rmtree(vectors_cache_dir) @@ -61,8 +62,11 @@ def test_init_2(self): pathdir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets')) filename = 'fasttext_sample_not_exist.vec.zip' - url_base = urljoin('file:', pathname2url(os.path.join(pathdir, filename))) - mftb = FastTextBinary(filename, url_base=url_base, cache=vectors_cache_dir) + url_base = urljoin('file:', + pathname2url(os.path.join(pathdir, filename))) + mftb = FastTextBinary(filename, + url_base=url_base, + cache=vectors_cache_dir) if os.path.exists(vectors_cache_dir): shutil.rmtree(vectors_cache_dir) @@ -75,15 +79,17 @@ def test_init_3(self): pathdir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets')) filename = 'fasttext_sample_not_exist.gz' - url_base = urljoin('file:', pathname2url(os.path.join(pathdir, filename))) - mftb = FastTextBinary(filename, url_base=url_base, cache=vectors_cache_dir) + url_base = urljoin('file:', + pathname2url(os.path.join(pathdir, filename))) + mftb = FastTextBinary(filename, + url_base=url_base, + cache=vectors_cache_dir) if os.path.exists(vectors_cache_dir): shutil.rmtree(vectors_cache_dir) class ClassMatchingFieldTestCases(unittest.TestCase): - def test_init_1(self): mf = MatchingField() self.assertTrue(mf.sequential) @@ -92,7 +98,8 @@ def test_init_2(self): mf = MatchingField() seq = 'Hello, This is a test sequence for tokenizer.' tok_seq = [ - 'Hello', ',', 'This', 'is', 'a', 'test', 'sequence', 'for', 'tokenizer', '.' + 'Hello', ',', 'This', 'is', 'a', 'test', 'sequence', 'for', + 'tokenizer', '.' ] self.assertEqual(mf.tokenize(seq), tok_seq) @@ -175,7 +182,12 @@ def test_extend_vocab_1(self): mf = MatchingField() lf = MatchingField(id=True, sequential=False) fields = [('id', lf), ('left_a', mf), ('right_a', mf), ('label', lf)] - col_naming = {'id': 'id', 'label': 'label', 'left': 'left_', 'right': 'right_'} + col_naming = { + 'id': 'id', + 'label': 'label', + 'left': 'left_', + 'right': 'right_' + } pathdir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets')) filename = 'fasttext_sample.vec' @@ -183,7 +195,8 @@ def test_extend_vocab_1(self): url_base = urljoin('file:', pathname2url(file)) vecs = Vectors(name=filename, cache=vectors_cache_dir, url=url_base) - data_path = os.path.join(test_dir_path, 'test_datasets', 'sample_table_small.csv') + data_path = os.path.join(test_dir_path, 'test_datasets', + 'sample_table_small.csv') md = MatchingDataset(fields, col_naming, path=data_path) mf.build_vocab() @@ -194,7 +207,6 @@ def test_extend_vocab_1(self): class TestResetVectorCache(unittest.TestCase): - def test_reset_vector_cache_1(self): mf = MatchingField() reset_vector_cache() @@ -202,7 +214,6 @@ def test_reset_vector_cache_1(self): class ClassMatchingVocabTestCases(unittest.TestCase): - def test_extend_vectors_1(self): vectors_cache_dir = '.cache' if os.path.exists(vectors_cache_dir): diff --git a/test/test_integration.py b/test/test_integration.py index 0261f4d..2ee4a22 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -16,6 +16,7 @@ from test import test_dir_path + class ModelTrainSaveLoadTest(unittest.TestCase): def setUp(self): self.vectors_cache_dir = '.cache' @@ -23,14 +24,16 @@ def setUp(self): shutil.rmtree(self.vectors_cache_dir) self.data_cache_path = os.path.join(test_dir_path, 'test_datasets', - 'train_cache.pth') + 'train_cache.pth') if os.path.exists(self.data_cache_path): os.remove(self.data_cache_path) vec_dir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets')) filename = 'fasttext_sample.vec.zip' url_base = urljoin('file:', pathname2url(vec_dir)) + os.path.sep - ft = FastText(filename, url_base=url_base, cache=self.vectors_cache_dir) + ft = FastText(filename, + url_base=url_base, + cache=self.vectors_cache_dir) self.train, self.valid, self.test = process( path=os.path.join(test_dir_path, 'test_datasets'), @@ -52,13 +55,12 @@ def tearDown(self): def test_sif(self): model_save_path = 'sif_model.pth' model = MatchingModel(attr_summarizer='sif') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) s1 = model.run_eval(self.test) model2 = MatchingModel(attr_summarizer='sif') @@ -73,13 +75,12 @@ def test_sif(self): def test_rnn(self): model_save_path = 'rnn_model.pth' model = MatchingModel(attr_summarizer='rnn') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) s1 = model.run_eval(self.test) model2 = MatchingModel(attr_summarizer='rnn') @@ -94,13 +95,12 @@ def test_rnn(self): def test_attention(self): model_save_path = 'attention_model.pth' model = MatchingModel(attr_summarizer='attention') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) s1 = model.run_eval(self.test) @@ -116,13 +116,12 @@ def test_attention(self): def test_hybrid(self): model_save_path = 'hybrid_model.pth' model = MatchingModel(attr_summarizer='hybrid') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) s1 = model.run_eval(self.test) @@ -137,23 +136,20 @@ def test_hybrid(self): def test_hybrid_self_attention(self): model_save_path = 'self_att_hybrid_model.pth' - model = MatchingModel( - attr_summarizer=attr_summarizers.Hybrid( - word_contextualizer='self-attention')) - - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) + model = MatchingModel(attr_summarizer=attr_summarizers.Hybrid( + word_contextualizer='self-attention')) + + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) s1 = model.run_eval(self.test) - model2 = MatchingModel( - attr_summarizer=attr_summarizers.Hybrid( - word_contextualizer='self-attention')) + model2 = MatchingModel(attr_summarizer=attr_summarizers.Hybrid( + word_contextualizer='self-attention')) model2.load_state(model_save_path) s2 = model2.run_eval(self.test) @@ -170,14 +166,16 @@ def setUp(self): shutil.rmtree(self.vectors_cache_dir) self.data_cache_path = os.path.join(test_dir_path, 'test_datasets', - 'train_cache.pth') + 'train_cache.pth') if os.path.exists(self.data_cache_path): os.remove(self.data_cache_path) vec_dir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets')) filename = 'fasttext_sample.vec.zip' url_base = urljoin('file:', pathname2url(vec_dir)) + os.path.sep - ft = FastText(filename, url_base=url_base, cache=self.vectors_cache_dir) + ft = FastText(filename, + url_base=url_base, + cache=self.vectors_cache_dir) self.train, self.valid, self.test = process( path=os.path.join(test_dir_path, 'test_datasets'), @@ -199,18 +197,17 @@ def tearDown(self): def test_sif(self): model_save_path = 'sif_model.pth' model = MatchingModel(attr_summarizer='sif') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) - - unlabeled = process_unlabeled( - path=os.path.join(test_dir_path, 'test_datasets', 'test_unlabeled.csv'), - trained_model=model, - ignore_columns=('left_id', 'right_id')) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) + + unlabeled = process_unlabeled(path=os.path.join( + test_dir_path, 'test_datasets', 'test_unlabeled.csv'), + trained_model=model, + ignore_columns=('left_id', 'right_id')) pred_test = model.run_eval(self.test, return_predictions=True) pred_unlabeled = model.run_prediction(unlabeled) @@ -224,13 +221,12 @@ def test_sif(self): def test_rnn(self): model_save_path = 'rnn_model.pth' model = MatchingModel(attr_summarizer='rnn') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) unlabeled = process_unlabeled( path=os.path.join(test_dir_path, 'test_datasets', 'test_test.csv'), @@ -249,18 +245,17 @@ def test_rnn(self): def test_attention(self): model_save_path = 'attention_model.pth' model = MatchingModel(attr_summarizer='attention') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) - - unlabeled = process_unlabeled( - path=os.path.join(test_dir_path, 'test_datasets', 'test_unlabeled.csv'), - trained_model=model, - ignore_columns=('left_id', 'right_id')) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) + + unlabeled = process_unlabeled(path=os.path.join( + test_dir_path, 'test_datasets', 'test_unlabeled.csv'), + trained_model=model, + ignore_columns=('left_id', 'right_id')) pred_test = model.run_eval(self.test, return_predictions=True) pred_unlabeled = model.run_prediction(unlabeled) @@ -274,18 +269,17 @@ def test_attention(self): def test_hybrid(self): model_save_path = 'hybrid_model.pth' model = MatchingModel(attr_summarizer='hybrid') - model.run_train( - self.train, - self.valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) - - unlabeled = process_unlabeled( - path=os.path.join(test_dir_path, 'test_datasets', 'test_unlabeled.csv'), - trained_model=model, - ignore_columns=('left_id', 'right_id')) + model.run_train(self.train, + self.valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) + + unlabeled = process_unlabeled(path=os.path.join( + test_dir_path, 'test_datasets', 'test_unlabeled.csv'), + trained_model=model, + ignore_columns=('left_id', 'right_id')) pred_test = model.run_eval(self.test, return_predictions=True) pred_unlabeled = model.run_prediction(unlabeled) diff --git a/test/test_iterator.py b/test/test_iterator.py index e81a99f..52c3ee8 100644 --- a/test/test_iterator.py +++ b/test/test_iterator.py @@ -12,6 +12,7 @@ from urllib.parse import urljoin from urllib.request import pathname2url + class ClassMatchingIteratorTestCases(unittest.TestCase): def test_splits_1(self): vectors_cache_dir = '.cache' @@ -32,17 +33,25 @@ def test_splits_1(self): url_base = urljoin('file:', pathname2url(pathdir)) + os.path.sep ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir) - datasets = process(data_dir, train=train_path, validation=valid_path, - test=test_path, cache=cache_file, embeddings=ft, - id_attr='_id', left_prefix='ltable_', right_prefix='rtable_', - embeddings_cache_path='',pca=False) + datasets = process(data_dir, + train=train_path, + validation=valid_path, + test=test_path, + cache=cache_file, + embeddings=ft, + id_attr='_id', + left_prefix='ltable_', + right_prefix='rtable_', + embeddings_cache_path='', + pca=False) splits = MatchingIterator.splits(datasets, batch_size=16) self.assertEqual(splits[0].batch_size, 16) self.assertEqual(splits[1].batch_size, 16) self.assertEqual(splits[2].batch_size, 16) splits_sorted = MatchingIterator.splits(datasets, - batch_sizes=[16, 32, 64], sort_in_buckets=False) + batch_sizes=[16, 32, 64], + sort_in_buckets=False) self.assertEqual(splits_sorted[0].batch_size, 16) self.assertEqual(splits_sorted[1].batch_size, 32) self.assertEqual(splits_sorted[2].batch_size, 64) @@ -72,18 +81,27 @@ def test_create_batches_1(self): url_base = urljoin('file:', pathname2url(pathdir)) + os.path.sep ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir) - datasets = process(data_dir, train=train_path, validation=valid_path, - test=test_path, cache=cache_file, embeddings=ft, - id_attr='_id', left_prefix='ltable_', right_prefix='rtable_', - embeddings_cache_path='',pca=False) + datasets = process(data_dir, + train=train_path, + validation=valid_path, + test=test_path, + cache=cache_file, + embeddings=ft, + id_attr='_id', + left_prefix='ltable_', + right_prefix='rtable_', + embeddings_cache_path='', + pca=False) splits = MatchingIterator.splits(datasets, batch_size=16) batch_splits = [split.create_batches() for split in splits] sorted_splits = MatchingIterator.splits(datasets, - batch_sizes=[16, 32, 64], sort_in_buckets=False) - batch_sorted_splits = [sorted_split.create_batches() - for sorted_split in sorted_splits] + batch_sizes=[16, 32, 64], + sort_in_buckets=False) + batch_sorted_splits = [ + sorted_split.create_batches() for sorted_split in sorted_splits + ] if os.path.exists(vectors_cache_dir): shutil.rmtree(vectors_cache_dir) diff --git a/test/test_process.py b/test/test_process.py index 412702c..0a5875d 100644 --- a/test/test_process.py +++ b/test/test_process.py @@ -15,68 +15,71 @@ class CheckHeaderTestCases(unittest.TestCase): - def test_check_header_1(self): path = os.path.join(test_dir_path, 'test_datasets') a_dataset = 'sample_table_small.csv' - with io.open( - os.path.expanduser(os.path.join(path, a_dataset)), encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(path, a_dataset)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) self.assertEqual(header, ['id', 'left_a', 'right_a', 'label']) id_attr = 'id' label_attr = 'label' left_prefix = 'left' right_prefix = 'right' - _check_header(header, id_attr, left_prefix, right_prefix, label_attr, []) + _check_header(header, id_attr, left_prefix, right_prefix, label_attr, + []) @raises(ValueError) def test_check_header_2(self): path = os.path.join(test_dir_path, 'test_datasets') a_dataset = 'sample_table_small.csv' - with io.open( - os.path.expanduser(os.path.join(path, a_dataset)), encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(path, a_dataset)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) self.assertEqual(header, ['id', 'left_a', 'right_a', 'label']) id_attr = 'id' label_attr = 'label' left_prefix = 'left' right_prefix = 'bb' - _check_header(header, id_attr, left_prefix, right_prefix, label_attr, []) + _check_header(header, id_attr, left_prefix, right_prefix, label_attr, + []) @raises(ValueError) def test_check_header_3(self): path = os.path.join(test_dir_path, 'test_datasets') a_dataset = 'sample_table_small.csv' - with io.open( - os.path.expanduser(os.path.join(path, a_dataset)), encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(path, a_dataset)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) self.assertEqual(header, ['id', 'left_a', 'right_a', 'label']) id_attr = 'id' label_attr = 'label' left_prefix = 'aa' right_prefix = 'right' - _check_header(header, id_attr, left_prefix, right_prefix, label_attr, []) + _check_header(header, id_attr, left_prefix, right_prefix, label_attr, + []) @raises(ValueError) def test_check_header_5(self): path = os.path.join(test_dir_path, 'test_datasets') a_dataset = 'sample_table_small.csv' - with io.open( - os.path.expanduser(os.path.join(path, a_dataset)), encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(path, a_dataset)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) self.assertEqual(header, ['id', 'left_a', 'right_a', 'label']) id_attr = 'id' label_attr = '' left_prefix = 'left' right_prefix = 'right' - _check_header(header, id_attr, left_prefix, right_prefix, label_attr, []) + _check_header(header, id_attr, left_prefix, right_prefix, label_attr, + []) @raises(AssertionError) def test_check_header_6(self): path = os.path.join(test_dir_path, 'test_datasets') a_dataset = 'sample_table_small.csv' - with io.open( - os.path.expanduser(os.path.join(path, a_dataset)), encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(path, a_dataset)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) self.assertEqual(header, ['id', 'left_a', 'right_a', 'label']) header.pop(1) @@ -84,26 +87,27 @@ def test_check_header_6(self): label_attr = 'label' left_prefix = 'left' right_prefix = 'right' - _check_header(header, id_attr, left_prefix, right_prefix, label_attr, []) + _check_header(header, id_attr, left_prefix, right_prefix, label_attr, + []) class MakeFieldsTestCases(unittest.TestCase): - def test_make_fields_1(self): path = os.path.join(test_dir_path, 'test_datasets') a_dataset = 'sample_table_large.csv' - with io.open( - os.path.expanduser(os.path.join(path, a_dataset)), encoding="utf8") as f: + with io.open(os.path.expanduser(os.path.join(path, a_dataset)), + encoding="utf8") as f: header = next(unicode_csv_reader(f)) self.assertEqual(header, [ '_id', 'ltable_id', 'rtable_id', 'label', 'ltable_Song_Name', - 'ltable_Artist_Name', 'ltable_Price', 'ltable_Released', 'rtable_Song_Name', - 'rtable_Artist_Name', 'rtable_Price', 'rtable_Released' + 'ltable_Artist_Name', 'ltable_Price', 'ltable_Released', + 'rtable_Song_Name', 'rtable_Artist_Name', 'rtable_Price', + 'rtable_Released' ]) id_attr = '_id' label_attr = 'label' - fields = _make_fields(header, id_attr, label_attr, ['ltable_id', 'rtable_id'], - True, 'nltk', True) + fields = _make_fields(header, id_attr, label_attr, + ['ltable_id', 'rtable_id'], True, 'nltk', True) self.assertEqual(len(fields), 12) counter = {} for tup in fields: @@ -114,7 +118,6 @@ def test_make_fields_1(self): class ProcessTestCases(unittest.TestCase): - def test_process_1(self): vectors_cache_dir = '.cache' if os.path.exists(vectors_cache_dir): @@ -134,17 +137,16 @@ def test_process_1(self): url_base = urljoin('file:', pathname2url(pathdir)) + os.path.sep ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir) - process( - data_dir, - train=train_path, - validation=valid_path, - test=test_path, - id_attr='_id', - left_prefix='ltable_', - right_prefix='rtable_', - cache=cache_file, - embeddings=ft, - embeddings_cache_path='') + process(data_dir, + train=train_path, + validation=valid_path, + test=test_path, + id_attr='_id', + left_prefix='ltable_', + right_prefix='rtable_', + cache=cache_file, + embeddings=ft, + embeddings_cache_path='') if os.path.exists(vectors_cache_dir): shutil.rmtree(vectors_cache_dir) @@ -152,6 +154,7 @@ def test_process_1(self): if os.path.exists(cache_path): os.remove(cache_path) + class ProcessUnlabeledTestCases(unittest.TestCase): def test_process_unlabeled_1(self): vectors_cache_dir = '.cache' @@ -159,7 +162,7 @@ def test_process_unlabeled_1(self): shutil.rmtree(vectors_cache_dir) data_cache_path = os.path.join(test_dir_path, 'test_datasets', - 'cacheddata.pth') + 'cacheddata.pth') if os.path.exists(data_cache_path): os.remove(data_cache_path) @@ -168,26 +171,25 @@ def test_process_unlabeled_1(self): url_base = urljoin('file:', pathname2url(vec_dir)) + os.path.sep ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir) - train, valid, test = process( - path=os.path.join(test_dir_path, 'test_datasets'), - train='test_train.csv', - validation='test_valid.csv', - test='test_test.csv', - id_attr='id', - ignore_columns=('left_id', 'right_id'), - embeddings=ft, - embeddings_cache_path='', - pca=True) + train, valid, test = process(path=os.path.join(test_dir_path, + 'test_datasets'), + train='test_train.csv', + validation='test_valid.csv', + test='test_test.csv', + id_attr='id', + ignore_columns=('left_id', 'right_id'), + embeddings=ft, + embeddings_cache_path='', + pca=True) model_save_path = 'sif_model.pth' model = MatchingModel(attr_summarizer='sif') - model.run_train( - train, - valid, - epochs=1, - batch_size=8, - best_save_path= model_save_path, - pos_neg_ratio=3) + model.run_train(train, + valid, + epochs=1, + batch_size=8, + best_save_path=model_save_path, + pos_neg_ratio=3) test_unlabeled = process_unlabeled( path=os.path.join(test_dir_path, 'test_datasets', 'test_test.csv'),