diff --git a/q2_feature_table/_normalize.py b/q2_feature_table/_normalize.py index d9c2935..fcde136 100644 --- a/q2_feature_table/_normalize.py +++ b/q2_feature_table/_normalize.py @@ -8,6 +8,12 @@ import biom +import os + +import pandas as pd +from q2_types.feature_data import SequenceCharacteristicsDirectoryFormat +from rnanorm import CPM, CTF, CUF, FPKM, TMM, TPM, UQ + def rarefy(table: biom.Table, sampling_depth: int, with_replacement: bool = False) -> biom.Table: @@ -23,3 +29,90 @@ def rarefy(table: biom.Table, sampling_depth: int, 'shallow enough sampling depth.') return table + + +def normalize( + table: pd.DataFrame, + method: str, + m_trim: float = None, + a_trim: float = None, + gene_length: SequenceCharacteristicsDirectoryFormat = None, +) -> pd.DataFrame: + # Validate parameter combinations and set trim parameters + m_trim, a_trim = _validate_parameters( + method, m_trim, a_trim, gene_length) + + # Process gene_lengths input and define methods that need gene_lengths + # input + if method in ["tpm", "fpkm"]: + lengths = _convert_lengths(table, gene_length) + + methods = { + "tpm": TPM(gene_lengths=lengths), + "fpkm": FPKM(gene_lengths=lengths), + } + + # Define remaining methods that don't need gene_lengths input + else: + methods = { + "tmm": TMM(m_trim=m_trim, a_trim=a_trim), + "ctf": CTF(m_trim=m_trim, a_trim=a_trim), + "uq": UQ(), + "cuf": CUF(), + "cpm": CPM(), + } + + # Run normalization method on frequency table + normalized = methods[method].set_output( + transform="pandas").fit_transform(table) + normalized.index.name = "sample_id" + + return normalized + + +def _validate_parameters(method, m_trim, a_trim, gene_length): + # Raise Error if gene-length is missing when using methods TPM or FPKM + if method in ["tpm", "fpkm"] and not gene_length: + raise ValueError("gene-length input is missing.") + + # Raise Error if gene-length is given when using methods TMM, UQ, CUF, CPM or CTF + if method in ["tmm", "uq", "cuf", "ctf", "cpm"] and gene_length: + raise ValueError( + "gene-length input can only be used with FPKM and TPM methods." + ) + + # Raise Error if m_trim or a_trim are given when not using methods TMM or CTF + if (method not in ["tmm", "ctf"]) and (m_trim is not None or a_trim is not None): + raise ValueError( + "Parameters m-trim and a-trim can only be used with methods TMM and CTF." + ) + + # Set m_trim and a_trim to their default values for methods TMM and CTF + if method in ["tmm", "ctf"]: + m_trim = 0.3 if m_trim is None else m_trim + a_trim = 0.05 if a_trim is None else a_trim + + return m_trim, a_trim + + +def _convert_lengths(table, gene_length): + # Read in table from sequence_characteristics.tsv as a pd.Series + lengths = pd.read_csv( + os.path.join(gene_length.path, "sequence_characteristics.tsv"), + sep="\t", + header=None, + names=["index", "values"], + index_col="index", + skiprows=1, + ).squeeze("columns") + + # Check if all gene IDs that are present in the table are also present in + # the lengths + if not set(table.columns).issubset(set(lengths.index)): + only_in_counts = set(table.columns) - set(lengths.index) + raise ValueError( + f"There are genes present in the FeatureTable that are not present " + f"in the gene-length input. Missing lengths for genes: " + f"{only_in_counts}" + ) + return lengths diff --git a/q2_feature_table/tests/test_normalize.py b/q2_feature_table/tests/test_normalize.py index d45ca94..d2d1726 100644 --- a/q2_feature_table/tests/test_normalize.py +++ b/q2_feature_table/tests/test_normalize.py @@ -5,14 +5,21 @@ # # The full license is in the file LICENSE, distributed with this software. # ---------------------------------------------------------------------------- - +import os +import shutil from unittest import TestCase, main +from unittest.mock import MagicMock, patch import numpy as np import numpy.testing as npt +import pandas as pd from biom.table import Table +from pandas._testing import assert_series_equal +from q2_types.feature_data import SequenceCharacteristicsDirectoryFormat from q2_feature_table import rarefy +from q2_feature_table._normalize import _validate_parameters, _convert_lengths, \ + normalize class RarefyTests(TestCase): @@ -53,5 +60,99 @@ def test_rarefy_depth_error(self): rarefy(t, 50) +class NormalizeTests(TestCase): + + @classmethod + def setUpClass(cls): + cls.lengths = pd.Series( + { + "ARO1": 1356.0, + "ARO2": 1173.0, + }, + name="values", + ) + cls.lengths.index.name = "index" + cls.table = pd.DataFrame({ + 'ID': ['sample1', 'sample2'], + 'ARO1': [2.0, 2.0], + 'ARO2': [0.0, 0.0] + }).set_index('ID') + + def test_validate_parameters_uq_with_m_a_trim(self): + # Test Error raised if gene-length is given with UQ method + with self.assertRaisesRegex( + ValueError, + "Parameters m-trim and a-trim can only " + "be used with methods TMM and CTF.", + ): + _validate_parameters("uq", 0.2, 0.05, None) + + def test_validate_parameters_tpm_missing_gene_length(self): + # Test Error raised if gene-length is missing with TPM method + with self.assertRaisesRegex(ValueError, "gene-length input is missing."): + _validate_parameters("tpm", None, None, None) + + def test_validate_parameters_tmm_gene_length(self): + # Test Error raised if gene-length is given with TMM method + with self.assertRaisesRegex( + ValueError, + "gene-length input can only be used with FPKM and TPM methods." + ): + _validate_parameters("tmm", None, None, gene_length=MagicMock()) + + def test_validate_parameters_default_m_a_trim(self): + # Test if m_trim and a_trim get set to default values if None + m_trim, a_trim = _validate_parameters("tmm", None, None, None) + self.assertEqual(m_trim, 0.3) + self.assertEqual(a_trim, 0.05) + + def test_validate_parameters_m_a_trim(self): + # Test if m_trim and a_trim are not modified if not None + m_trim, a_trim = _validate_parameters("tmm", 0.1, 0.06, None) + self.assertEqual(m_trim, 0.1) + self.assertEqual(a_trim, 0.06) + + def test_convert_lengths_gene_length(self): + # Test _convert_lengths + gene_length = SequenceCharacteristicsDirectoryFormat() + with open(os.path.join(str(gene_length), "sequence_characteristics.tsv"), + 'w') as file: + file.write("id\tlength\nARO1\t1356.0\nARO2\t1173.0") + + obs = _convert_lengths(self.table, gene_length=gene_length) + assert_series_equal(obs, self.lengths) + + def test_convert_lengths_short_gene_length(self): + # Test Error raised if gene-length is missing genes + gene_length = SequenceCharacteristicsDirectoryFormat() + with open(os.path.join(str(gene_length), "sequence_characteristics.tsv"), 'w') as file: + file.write("id\tlength\nARO1\t1356.0") + with self.assertRaisesRegex( + ValueError, + "There are genes present in the FeatureTable that are not present " + "in the gene-length input. Missing lengths for genes: " + "{'ARO2'}", + ): + _convert_lengths(self.table, gene_length=gene_length) + + @patch("q2_feature_table._normalize.TPM") + def test_tpm_fpkm_with_valid_inputs(self, mock_tpm): + # Test valid inputs for TPM method + gene_length = SequenceCharacteristicsDirectoryFormat() + with open(os.path.join(str(gene_length), "sequence_characteristics.tsv"), + 'w') as file: + file.write("id\tlength\nARO1\t1356.0\nARO2\t1173.0") + normalize(table=self.table, gene_length=gene_length, method="tpm") + + @patch("q2_feature_table._normalize.TMM") + def test_tmm_uq_cuf_ctf_with_valid_inputs(self, mock_tmm): + # Test valid inputs for TMM method + gene_length = SequenceCharacteristicsDirectoryFormat() + with open(os.path.join(str(gene_length), "sequence_characteristics.tsv"), + 'w') as file: + file.write("id\tlength\nARO1\t1356.0\nARO2\t1173.0") + normalize(table=self.table, method="tmm", a_trim=0.06, m_trim=0.4) + + if __name__ == "__main__": main()