diff --git a/src/pprl/app/utils.py b/src/pprl/app/utils.py index 83e5e43..ca79f42 100644 --- a/src/pprl/app/utils.py +++ b/src/pprl/app/utils.py @@ -125,7 +125,7 @@ def convert_dataframe_to_bf( feature types to be processed as appropriate. other_columns: list[str] Columns to be returned as they appear in the data in addition to - `bf_indices` and `bf_norms`. + `bf_indices`, `bf_norms` and `thresholds`. salt: str Cryptographic salt to add to tokens before hashing. @@ -134,10 +134,7 @@ def convert_dataframe_to_bf( output: pandas.DataFrame Data frame of bloom-filtered data. """ - if other_columns is None: - other_columns = [] - output_columns = other_columns + ["bf_indices", "bf_norms", "thresholds"] NGRAMS = [1, 2, 3, 4] FFARGS = {"name": {"ngram_length": NGRAMS, "use_gen_skip_grams": True}} @@ -156,6 +153,6 @@ def convert_dataframe_to_bf( ) df_bloom_filter = embedder.embed(df, colspec, update_norms=True, update_thresholds=True) - output = df_bloom_filter[output_columns] + output = df_bloom_filter.anonymise(other_columns) return output diff --git a/src/pprl/embedder/embedder.py b/src/pprl/embedder/embedder.py index 8f1e2fc..24771b2 100644 --- a/src/pprl/embedder/embedder.py +++ b/src/pprl/embedder/embedder.py @@ -157,6 +157,29 @@ def update_norms(self) -> "EmbeddedDataFrame": return self + def anonymise(self, keep: None | list = None) -> "EmbeddedDataFrame": + """Remove raw data from embedded dataframe. + + Remove all columns from the embedded dataframe except columns listed + in keep and `bf_indices`, `bf_norms` and `thresholds`. + + Parameters + ---------- + keep: list[str] + Columns to be returned as they appear in the data in addition to + `bf_indices`, `bf_norms` and `thresholds` if they are present in + the data. 
+ """ + + if keep is None: + keep = [] + + output_columns = keep + ["bf_indices", "bf_norms", "thresholds"] + output_columns = [column for column in self.columns if column in output_columns] + # remove duplicate column names + output_columns = list(dict.fromkeys(output_columns)) + return self[output_columns] + class SimilarityArray(np.ndarray): """Augmented NumPy array of similarity scores with extra attributes. diff --git a/test/embedder/test_embedder.py b/test/embedder/test_embedder.py index 5684829..f00b4f3 100644 --- a/test/embedder/test_embedder.py +++ b/test/embedder/test_embedder.py @@ -75,6 +75,32 @@ def test_update_norms(posdef_matrix): assert bf_norms1 == bf_norms2 +def test_anonymise(): + """Tests EmbeddedDataFrame.anonymise. + + Test that the columns in the keep list are returned in their + original order in addition to the bf_indices column. + """ + + matrix = np.eye(5) + df = pd.DataFrame( + dict( + idx=[1], + firstname=["Fred"], + age=[43], + lastname=["Hogan O'Malley"], + bf_indices=[45], + ) + ) + embedder_mock = mock.Mock(Embedder) + embedder_mock.scm_matrix = matrix + embedder_mock.checksum = "1234" + edf = EmbeddedDataFrame(df, embedder_mock, update_norms=False, update_thresholds=False) + + edf_anonymised = edf.anonymise(keep=["age", "lastname", "idx", "age"]) + assert list(edf_anonymised.columns) == ["idx", "age", "lastname", "bf_indices"] + + def test_embed_colspec(): """Check that only the name column in the colspec is processed."""