Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement spectrum_output module for v4 #220

Merged
merged 15 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions docs/source/api/ms2pip.constants.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
ms2pip.constants
****************

.. py:data:: ms2pip.constants.SUPPORTED_OUTPUT_FORMATS
:type: list

Supported file formats for spectrum output

.. py:data:: ms2pip.constants.MODELS
:type: dict

Expand Down
5 changes: 5 additions & 0 deletions docs/source/api/ms2pip.spectrum-output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ ms2pip.spectrum_output

.. automodule:: ms2pip.spectrum_output
:members:

.. py:data:: ms2pip.spectrum_output.SUPPORTED_FORMATS
:type: dict

Supported file formats and respective :py:class:`_Writer` class for spectrum output.
76 changes: 33 additions & 43 deletions ms2pip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,12 @@
from werkzeug.utils import secure_filename

import ms2pip.core
from ms2pip import __version__
from ms2pip import __version__, exceptions
from ms2pip._utils.cli import build_credits, build_prediction_table
from ms2pip.constants import MODELS, SUPPORTED_OUTPUT_FORMATS
from ms2pip.exceptions import (
InvalidXGBoostModelError,
UnknownModelError,
UnknownOutputFormatError,
UnresolvableModificationError,
)
from ms2pip.result import correlations_to_csv, results_to_csv
from ms2pip.spectrum_output import write_single_spectrum_csv, write_single_spectrum_png
from ms2pip.constants import MODELS
from ms2pip.plot import spectrum_to_png
from ms2pip.result import write_correlations
from ms2pip.spectrum_output import SUPPORTED_FORMATS, write_spectra

console = Console()
logger = logging.getLogger(__name__)
Expand All @@ -44,7 +39,8 @@ def _infer_output_name(
if output_name:
return Path(output_name)
else:
return Path(input_filename).with_suffix("")
input__filename = Path(input_filename)
return input__filename.with_name(input__filename.stem + "_predictions").with_suffix("")


@click.group()
Expand All @@ -65,49 +61,47 @@ def cli(*args, **kwargs):
@cli.command(help=ms2pip.core.predict_single.__doc__)
@click.argument("peptidoform", required=True)
@click.option("--output-name", "-o", type=str)
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_FORMATS), default="tsv")
@click.option("--model", type=click.Choice(MODELS), default="HCD")
@click.option("--model-dir")
@click.option("--plot", "-p", is_flag=True)
def predict_single(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_format = kwargs.pop("output_format")
plot = kwargs.pop("plot")
if not output_name:
output_name = "ms2pip_prediction_" + secure_filename(kwargs["peptidoform"]) + ".csv"
output_name = "ms2pip_prediction_" + secure_filename(kwargs["peptidoform"])

# Predict spectrum
result = ms2pip.core.predict_single(*args, **kwargs)
predicted_spectrum, _ = result.as_spectra()

# Write output
console.print(build_prediction_table(predicted_spectrum))
write_single_spectrum_csv(predicted_spectrum, output_name)
write_spectra(output_name, [result], output_format)
if plot:
write_single_spectrum_png(predicted_spectrum, output_name)
spectrum_to_png(predicted_spectrum, output_name)


@cli.command(help=ms2pip.core.predict_batch.__doc__)
@click.argument("psms", required=True)
@click.option("--output-name", "-o", type=str)
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_OUTPUT_FORMATS))
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_FORMATS), default="tsv")
@click.option("--add-retention-time", "-r", is_flag=True)
@click.option("--model", type=click.Choice(MODELS), default="HCD")
@click.option("--model-dir")
@click.option("--processes", "-n", type=int)
def predict_batch(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_format = kwargs.pop("output_format") # noqa F841 TODO
output_name = _infer_output_name(kwargs["psms"], output_name)
output_format = kwargs.pop("output_format")
output_name = _infer_output_name(kwargs["psms"], kwargs.pop("output_name"))

# Run
predictions = ms2pip.core.predict_batch(*args, **kwargs)

# Write output
output_name_csv = output_name.with_name(output_name.stem + "_predictions").with_suffix(".csv")
logger.info(f"Writing output to {output_name_csv}")
results_to_csv(predictions, output_name_csv)
# TODO: add support for other output formats
write_spectra(output_name, predictions, output_format)


@cli.command(help=ms2pip.core.predict_library.__doc__)
Expand All @@ -129,24 +123,22 @@ def predict_library(*args, **kwargs):
@click.option("--processes", "-n", type=int)
def correlate(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_name = _infer_output_name(kwargs["psms"], output_name)
output_name = _infer_output_name(kwargs["psms"], kwargs.pop("output_name"))

# Run
results = ms2pip.core.correlate(*args, **kwargs)

# Write output
output_name_int = output_name.with_name(output_name.stem + "_predictions").with_suffix(".csv")
logger.info(f"Writing intensities to {output_name_int}")
results_to_csv(results, output_name_int)
# TODO: add support for other output formats
# Write intensities
logger.info(f"Writing intensities to {output_name.with_suffix('.tsv')}")
write_spectra(output_name, results, "tsv")

# Write correlations
if kwargs["compute_correlations"]:
output_name_corr = output_name.with_name(output_name.stem + "_correlations")
output_name_corr = output_name_corr.with_suffix(".csv")
output_name_corr = output_name.with_name(output_name.stem + "_correlations").with_suffix(
".tsv"
)
logger.info(f"Writing correlations to {output_name_corr}")
correlations_to_csv(results, output_name_corr)
write_correlations(results, output_name_corr)


@cli.command(help=ms2pip.core.get_training_data.__doc__)
Expand Down Expand Up @@ -188,32 +180,30 @@ def annotate_spectra(*args, **kwargs):
# Run
results = ms2pip.core.annotate_spectra(*args, **kwargs)

# Write output
output_name_int = output_name.with_name(output_name.stem + "_observations").with_suffix(".csv")
logger.info(f"Writing intensities to {output_name_int}")
results_to_csv(results, output_name_int)
# Write intensities
output_name_int = output_name.with_name(output_name.stem + "_observations").with_suffix()
logger.info(f"Writing intensities to {output_name_int.with_suffix('.tsv')}")
write_spectra(output_name, results, "tsv")


def main():
try:
cli()
except UnresolvableModificationError as e:
except exceptions.UnresolvableModificationError as e:
logger.critical(
"Unresolvable modification: `%s`. See "
"https://ms2pip.readthedocs.io/en/stable/usage/#amino-acid-modifications "
"for more info.",
e,
)
sys.exit(1)
except UnknownOutputFormatError as o:
logger.critical(
f"Unknown output format: `{o}` (supported formats: `{SUPPORTED_OUTPUT_FORMATS}`)"
)
except exceptions.UnknownOutputFormatError as o:
logger.critical(f"Unknown output format: `{o}` (supported formats: `{SUPPORTED_FORMATS}`)")
sys.exit(1)
except UnknownModelError as f:
except exceptions.UnknownModelError as f:
logger.critical(f"Unknown model: `{f}` (supported models: {set(MODELS.keys())})")
sys.exit(1)
except InvalidXGBoostModelError:
except exceptions.InvalidXGBoostModelError:
logger.critical("Could not correctly download XGBoost model\nTry a manual download.")
sys.exit(1)
except Exception:
Expand Down
4 changes: 3 additions & 1 deletion ms2pip/_utils/dlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Database configuration for EncyclopeDIA DLIB SQLite format."""

import zlib
from pathlib import Path
from typing import Union

import numpy
import sqlalchemy
Expand Down Expand Up @@ -91,7 +93,7 @@ def copy(self):
)


def open_sqlite(filename):
def open_sqlite(filename: Union[str, Path]) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(f"sqlite:///{filename}")
metadata.bind = engine
return engine.connect()
6 changes: 0 additions & 6 deletions ms2pip/_utils/psm_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import psm_utils.io.peptide_record
from psm_utils import PSMList

from ms2pip import exceptions

logger = logging.getLogger(__name__)


Expand All @@ -23,10 +21,6 @@ def read_psms(psms: Union[str, Path, PSMList], filetype: Union[str, None]) -> PS
else:
raise TypeError("Invalid type for psms. Should be str, Path, or PSMList.")

# Validate runs and collections
if not len(psm_list.collections) == 1 or not len(psm_list.runs) == 1:
raise exceptions.InvalidInputError("PSMs should be for a single run and collection.")

# Apply fixed modifications if any
psm_list.apply_fixed_modifications()

Expand Down
1 change: 1 addition & 0 deletions ms2pip/_utils/xgb_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_predictions_xgb(features, num_ions, model_params, model_dir, processes=1
for ion_type, xgb_model in xgboost_models.items():
# Get predictions from XGBoost model
preds = xgb_model.predict(features)
preds = preds.clip(min=np.log2(0.001)) # Clip negative intensities
xgb_model.__del__()

# Reshape into arrays for each peptide
Expand Down
3 changes: 0 additions & 3 deletions ms2pip/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""Constants and fixed configurations for MS²PIP."""

# Supported output formats
SUPPORTED_OUTPUT_FORMATS = ["csv", "mgf", "msp", "bibliospec", "spectronaut", "dlib"]

# Models and their properties
# id is passed to get_predictions to select model
# ion_types is required to write the ion types in the headers of the result files
Expand Down
14 changes: 11 additions & 3 deletions ms2pip/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from ms2pip._utils.psm_input import read_psms
from ms2pip._utils.retention_time import RetentionTime
from ms2pip._utils.xgb_models import get_predictions_xgb, validate_requested_xgb_model
from ms2pip.constants import MODELS, SUPPORTED_OUTPUT_FORMATS
from ms2pip.constants import MODELS
from ms2pip.result import ProcessingResult, calculate_correlations
from ms2pip.spectrum_input import read_spectrum_file
from ms2pip.spectrum_output import SUPPORTED_FORMATS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -424,7 +425,7 @@ def _validate_output_formats(self, output_formats: List[str]) -> List[str]:
self.output_formats = ["csv"]
else:
for output_format in output_formats:
if output_format not in SUPPORTED_OUTPUT_FORMATS:
if output_format not in SUPPORTED_FORMATS:
raise exceptions.UnknownOutputFormatError(output_format)
self.output_formats = output_formats

Expand Down Expand Up @@ -544,6 +545,10 @@ def process_spectra(
If only peak annotations should be extracted from the spectrum file

"""
# Validate runs and collections
if not len(psm_list.collections) == 1 or not len(psm_list.runs) == 1:
raise exceptions.InvalidInputError("PSMs should be for a single run and collection.")

args = (
spectrum_file,
vector_file,
Expand Down Expand Up @@ -672,7 +677,10 @@ def _process_peptidoform(
MODELS[model]["peaks_version"],
30.0, # TODO: Remove CE feature
)
predictions = {i: np.array(p, dtype=np.float32) for i, p in zip(ion_types, predictions)}
predictions = {
i: np.array(p, dtype=np.float32).clip(min=np.log2(0.001)) # Clip negative intensities
for i, p in zip(ion_types, predictions)
}
feature_vectors = None

return ProcessingResult(
Expand Down
23 changes: 23 additions & 0 deletions ms2pip/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path
from typing import Union

from ms2pip.spectrum import Spectrum

try:
import matplotlib.pyplot as plt
import spectrum_utils.plot as sup

_can_plot = True
except ImportError:
_can_plot = False


def spectrum_to_png(spectrum: Spectrum, filepath: Union[str, Path]):
"""Plot a single spectrum and write to a PNG file."""
if not _can_plot:
raise ImportError("Matplotlib and spectrum_utils are required to plot spectra.")
ax = plt.gca()
ax.set_title("MS²PIP prediction for " + str(spectrum.peptidoform))
sup.spectrum(spectrum.to_spectrum_utils(), ax=ax)
plt.savefig(Path(filepath).with_suffix(".png"))
plt.close()
40 changes: 4 additions & 36 deletions ms2pip/result.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Definition and handling of MS²PIP results."""

from __future__ import annotations

import csv
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from psm_utils import PSM
from pydantic import ConfigDict, BaseModel
from pydantic import BaseModel, ConfigDict

try:
import spectrum_utils.plot as sup
Expand Down Expand Up @@ -115,44 +116,11 @@ def calculate_correlations(results: List[ProcessingResult]) -> None:
result.correlation = np.corrcoef(pred_int, obs_int)[0][1]


def results_to_csv(results: List["ProcessingResult"], output_file: str) -> None:
"""Write processing results to CSV file."""
with open(output_file, "wt") as f:
fieldnames = [
"psm_index",
"ion_type",
"ion_number",
"mz",
"predicted",
"observed",
]
writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()
for result in results:
if result.theoretical_mz is not None:
for ion_type in result.theoretical_mz:
for i in range(len(result.theoretical_mz[ion_type])):
writer.writerow(
{
"psm_index": result.psm_index,
"ion_type": ion_type,
"ion_number": i + 1,
"mz": "{:.6g}".format(result.theoretical_mz[ion_type][i]),
"predicted": "{:.6g}".format(
result.predicted_intensity[ion_type][i]
) if result.predicted_intensity else None,
"observed": "{:.6g}".format(result.observed_intensity[ion_type][i])
if result.observed_intensity
else None,
}
)


def correlations_to_csv(results: List["ProcessingResult"], output_file: str) -> None:
def write_correlations(results: List["ProcessingResult"], output_file: str) -> None:
"""Write correlations to CSV file."""
with open(output_file, "wt") as f:
fieldnames = ["psm_index", "correlation"]
writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter="\t", lineterminator="\n")
writer.writeheader()
for result in results:
writer.writerow({"psm_index": result.psm_index, "correlation": result.correlation})
6 changes: 3 additions & 3 deletions ms2pip/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def __repr__(self) -> str:
@model_validator(mode="after")
@classmethod
def check_array_lengths(cls, data: dict):
if len(data["mz"]) != len(data["intensity"]):
if len(data.mz) != len(data.intensity):
raise ValueError("Array lengths do not match.")
if data["annotations"] is not None:
if len(data["annotations"]) != len(data["intensity"]):
if data.annotations is not None:
if len(data.annotations) != len(data.intensity):
raise ValueError("Array lengths do not match.")
return data

Expand Down
Loading
Loading