Skip to content
This repository has been archived by the owner on Aug 9, 2023. It is now read-only.

Commit

Permalink
Merge pull request #276 from wellcometrust/feature/clustering-improvements
Browse files Browse the repository at this point in the history

Feature/clustering improvements
  • Loading branch information
aCampello authored Apr 26, 2021
2 parents b5ef951 + 4003a70 commit e81be34
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 7 deletions.
13 changes: 12 additions & 1 deletion tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@pytest.mark.parametrize("reducer,cluster_reduced", [("tsne", True),
("umap", True),
("umap", False)])
def test_full_pipeline(reducer, cluster_reduced):
def test_full_pipeline(reducer, cluster_reduced, tmp_path):
cluster = TextClustering(reducer=reducer, cluster_reduced=cluster_reduced,
embedding_random_state=42,
reducer_random_state=43,
Expand All @@ -23,6 +23,17 @@ def test_full_pipeline(reducer, cluster_reduced):

assert len(cluster.cluster_kws) == len(cluster.cluster_ids) == 6

cluster.save(folder=tmp_path)

cluster_new = TextClustering()
cluster_new.load(folder=tmp_path)

# Asserts all coordinates of the loaded points are equal
assert (cluster_new.embedded_points != cluster.embedded_points).sum() == 0
assert (cluster_new.reduced_points != cluster.reduced_points).sum() == 0
assert cluster_new.reducer_class.__class__ == cluster.reducer_class.__class__
assert cluster_new.clustering_class.__class__ == cluster.clustering_class.__class__


@pytest.mark.parametrize("reducer", ["tsne", "umap"])
def test_parameter_search(reducer):
Expand Down
98 changes: 92 additions & 6 deletions wellcomeml/ml/clustering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
import logging
import os
import pickle

from wellcomeml.ml import vectorizer
from wellcomeml.logger import logger
Expand Down Expand Up @@ -39,6 +40,7 @@ class TextClustering(object):
cluster_names: Names of the clusters
cluster_kws: Keywords for the clusters (only if embedding=tf-idf)
"""

def __init__(self, embedding='tf-idf', reducer='umap', clustering='dbscan',
cluster_reduced=True, n_kw=10, params={},
embedding_random_state=None, reducer_random_state=None,
Expand Down Expand Up @@ -113,13 +115,21 @@ class that is a sklearn.base.ClusterMixin
'random_state'):
self.clustering_class.random_state = clustering_random_state

self.embedded_points = None
self.reduced_points = None
self.cluster_ids = None
self.cluster_names = None
self.cluster_kws = None
self.kw_dictionary = {}
self.silhouette = None
self.optimise_results = {}

self.embedded_points_filename = 'embedded_points.npy'
self.reduced_points_filename = 'reduced_points.npy'
self.vectorizer_filename = 'vectorizer.pkl'
self.reducer_filename = 'reducer.pkl'
self.clustering_filename = 'clustering.pkl'

def fit(self, X, *_):
"""
Fits all clusters in the pipeline
Expand All @@ -131,22 +141,28 @@ def fit(self, X, *_):
A TextClustering object
"""
self._fit_step(X, step='vectorizer')
self._fit_step(step='reducer')
self._fit_step(step='clustering')
self.fit_step(X, step='vectorizer')
self.fit_step(step='reducer')
self.fit_step(step='clustering')

if self.embedding == 'tf-idf' and self.n_kw:
self._find_keywords(self.embedded_points.toarray(), n_kw=self.n_kw)

return self

def _fit_step(self, X=None, step='vectorizer'):
def fit_step(self, X=None, y=None, step='vectorizer'):
"""Internal function for partial fitting only a certain step"""
if step == 'vectorizer':
self.embedded_points = self.vectorizer.fit_transform(X)
elif step == 'reducer':
self.reduced_points = \
self.reducer_class.fit_transform(self.embedded_points)
if self.embedded_points is None:
raise ValueError(
'You must embed/vectorise the points before reducing dimensionality'
)
if X is None:
X = self.embedded_points

self.reduced_points = self.reducer_class.fit_transform(X=X, y=y)
elif step == 'clustering':
points = (
self.reduced_points if self.cluster_reduced else
Expand Down Expand Up @@ -260,7 +276,9 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,
# Prunes result to actually optimise under constraints
best_silhouette = 0
best_params = {}

grid.fit(X, y=None)

for params, silhouette, noise, n_clusters in zip(
grid.cv_results_['params'],
grid.cv_results_['mean_test_silhouette'],
Expand Down Expand Up @@ -292,6 +310,74 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,

return best_params

def save(self, folder, components='all', create_folder=True):
    """
    Saves the different steps of the pipeline to disk.

    Args:
        folder(str): path to the folder to write into
        components(list or 'all'): List of components to save. Options are:
            'embedded_points', 'reduced_points', 'vectorizer', 'reducer',
            and 'clustering_model'. By default, saves 'all'.
        create_folder(bool): whether to create `folder` (and any missing
            parents) before writing. Defaults to True.
    """
    if create_folder:
        # exist_ok so saving twice into the same folder is not an error
        os.makedirs(folder, exist_ok=True)

    # Point arrays are stored with numpy; fitted models are pickled.
    if components == 'all' or 'embedded_points' in components:
        np.save(os.path.join(folder, self.embedded_points_filename), self.embedded_points)

    if components == 'all' or 'reduced_points' in components:
        np.save(os.path.join(folder, self.reduced_points_filename), self.reduced_points)

    if components == 'all' or 'vectorizer' in components:
        with open(os.path.join(folder, self.vectorizer_filename), 'wb') as f:
            pickle.dump(self.vectorizer, f)

    if components == 'all' or 'reducer' in components:
        with open(os.path.join(folder, self.reducer_filename), 'wb') as f:
            pickle.dump(self.reducer_class, f)

    if components == 'all' or 'clustering_model' in components:
        with open(os.path.join(folder, self.clustering_filename), 'wb') as f:
            pickle.dump(self.clustering_class, f)

def load(self, folder, components='all'):
    """
    Loads the different steps of the pipeline from disk.

    Args:
        folder(str): path to the folder previously written by `save`
        components(list or 'all'): List of components to load. Options are:
            'embedded_points', 'reduced_points', 'vectorizer', 'reducer',
            and 'clustering_model'. By default, loads 'all'.
    """

    if components == 'all' or 'embedded_points' in components:
        # allow_pickle because the embedded points may not be a plain
        # ndarray (e.g. a sparse matrix — see the .toarray() call in fit),
        # in which case np.save wraps the object in a 0-d object array
        self.embedded_points = np.load(os.path.join(folder, self.embedded_points_filename),
                                       allow_pickle=True)
        if not self.embedded_points.shape:
            # unwrap a 0-d object array back to the original object
            self.embedded_points = self.embedded_points[()]

    if components == 'all' or 'reduced_points' in components:
        self.reduced_points = np.load(os.path.join(folder, self.reduced_points_filename),
                                      allow_pickle=True)

    if components == 'all' or 'vectorizer' in components:
        with open(os.path.join(folder, self.vectorizer_filename), 'rb') as f:
            self.vectorizer = pickle.load(f)

    if components == 'all' or 'reducer' in components:
        with open(os.path.join(folder, self.reducer_filename), 'rb') as f:
            self.reducer_class = pickle.load(f)

    if components == 'all' or 'clustering_model' in components:
        with open(os.path.join(folder, self.clustering_filename), 'rb') as f:
            self.clustering_class = pickle.load(f)

def stability(self):
    """Calculates how stable the clusters are.

    Not implemented yet; always raises NotImplementedError.
    """
    raise NotImplementedError
Expand Down

0 comments on commit e81be34

Please sign in to comment.