Skip to content

Commit

Permalink
Implement spectrum output for CLI functions
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfG committed May 8, 2024
1 parent cdb8ca1 commit bc65c93
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 94 deletions.
76 changes: 33 additions & 43 deletions ms2pip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,12 @@
from werkzeug.utils import secure_filename

import ms2pip.core
from ms2pip import __version__
from ms2pip import __version__, exceptions
from ms2pip._utils.cli import build_credits, build_prediction_table
from ms2pip.constants import MODELS, SUPPORTED_OUTPUT_FORMATS
from ms2pip.exceptions import (
InvalidXGBoostModelError,
UnknownModelError,
UnknownOutputFormatError,
UnresolvableModificationError,
)
from ms2pip.result import correlations_to_csv, results_to_csv
from ms2pip.spectrum_output import write_single_spectrum_csv, write_single_spectrum_png
from ms2pip.constants import MODELS
from ms2pip.plot import spectrum_to_png
from ms2pip.result import write_correlations
from ms2pip.spectrum_output import SUPPORTED_FORMATS, write_spectra

console = Console()
logger = logging.getLogger(__name__)
Expand All @@ -44,7 +39,8 @@ def _infer_output_name(
if output_name:
return Path(output_name)
else:
return Path(input_filename).with_suffix("")
input__filename = Path(input_filename)
return input__filename.with_name(input__filename.stem + "_predictions").with_suffix("")


@click.group()
Expand All @@ -65,49 +61,47 @@ def cli(*args, **kwargs):
@cli.command(help=ms2pip.core.predict_single.__doc__)
@click.argument("peptidoform", required=True)
@click.option("--output-name", "-o", type=str)
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_FORMATS), default="tsv")
@click.option("--model", type=click.Choice(MODELS), default="HCD")
@click.option("--model-dir")
@click.option("--plot", "-p", is_flag=True)
def predict_single(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_format = kwargs.pop("output_format")
plot = kwargs.pop("plot")
if not output_name:
output_name = "ms2pip_prediction_" + secure_filename(kwargs["peptidoform"]) + ".csv"
output_name = "ms2pip_prediction_" + secure_filename(kwargs["peptidoform"])

# Predict spectrum
result = ms2pip.core.predict_single(*args, **kwargs)
predicted_spectrum, _ = result.as_spectra()

# Write output
console.print(build_prediction_table(predicted_spectrum))
write_single_spectrum_csv(predicted_spectrum, output_name)
write_spectra(output_name, [result], output_format)
if plot:
write_single_spectrum_png(predicted_spectrum, output_name)
spectrum_to_png(predicted_spectrum, output_name)


@cli.command(help=ms2pip.core.predict_batch.__doc__)
@click.argument("psms", required=True)
@click.option("--output-name", "-o", type=str)
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_OUTPUT_FORMATS))
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_FORMATS), default="tsv")
@click.option("--add-retention-time", "-r", is_flag=True)
@click.option("--model", type=click.Choice(MODELS), default="HCD")
@click.option("--model-dir")
@click.option("--processes", "-n", type=int)
def predict_batch(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_format = kwargs.pop("output_format") # noqa F841 TODO
output_name = _infer_output_name(kwargs["psms"], output_name)
output_format = kwargs.pop("output_format")
output_name = _infer_output_name(kwargs["psms"], kwargs.pop("output_name"))

# Run
predictions = ms2pip.core.predict_batch(*args, **kwargs)

# Write output
output_name_csv = output_name.with_name(output_name.stem + "_predictions").with_suffix(".csv")
logger.info(f"Writing output to {output_name_csv}")
results_to_csv(predictions, output_name_csv)
# TODO: add support for other output formats
write_spectra(output_name, predictions, output_format)


@cli.command(help=ms2pip.core.predict_library.__doc__)
Expand All @@ -129,24 +123,22 @@ def predict_library(*args, **kwargs):
@click.option("--processes", "-n", type=int)
def correlate(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_name = _infer_output_name(kwargs["psms"], output_name)
output_name = _infer_output_name(kwargs["psms"], kwargs.pop("output_name"))

# Run
results = ms2pip.core.correlate(*args, **kwargs)

# Write output
output_name_int = output_name.with_name(output_name.stem + "_predictions").with_suffix(".csv")
logger.info(f"Writing intensities to {output_name_int}")
results_to_csv(results, output_name_int)
# TODO: add support for other output formats
# Write intensities
logger.info(f"Writing intensities to {output_name.with_suffix('.tsv')}")
write_spectra(output_name, results, "tsv")

# Write correlations
if kwargs["compute_correlations"]:
output_name_corr = output_name.with_name(output_name.stem + "_correlations")
output_name_corr = output_name_corr.with_suffix(".csv")
output_name_corr = output_name.with_name(output_name.stem + "_correlations").with_suffix(
".tsv"
)
logger.info(f"Writing correlations to {output_name_corr}")
correlations_to_csv(results, output_name_corr)
write_correlations(results, output_name_corr)


@cli.command(help=ms2pip.core.get_training_data.__doc__)
Expand Down Expand Up @@ -188,32 +180,30 @@ def annotate_spectra(*args, **kwargs):
# Run
results = ms2pip.core.annotate_spectra(*args, **kwargs)

# Write output
output_name_int = output_name.with_name(output_name.stem + "_observations").with_suffix(".csv")
logger.info(f"Writing intensities to {output_name_int}")
results_to_csv(results, output_name_int)
# Write intensities
output_name_int = output_name.with_name(output_name.stem + "_observations").with_suffix()
logger.info(f"Writing intensities to {output_name_int.with_suffix('.tsv')}")
write_spectra(output_name, results, "tsv")


def main():
try:
cli()
except UnresolvableModificationError as e:
except exceptions.UnresolvableModificationError as e:
logger.critical(
"Unresolvable modification: `%s`. See "
"https://ms2pip.readthedocs.io/en/stable/usage/#amino-acid-modifications "
"for more info.",
e,
)
sys.exit(1)
except UnknownOutputFormatError as o:
logger.critical(
f"Unknown output format: `{o}` (supported formats: `{SUPPORTED_OUTPUT_FORMATS}`)"
)
except exceptions.UnknownOutputFormatError as o:
logger.critical(f"Unknown output format: `{o}` (supported formats: `{SUPPORTED_FORMATS}`)")
sys.exit(1)
except UnknownModelError as f:
except exceptions.UnknownModelError as f:
logger.critical(f"Unknown model: `{f}` (supported models: {set(MODELS.keys())})")
sys.exit(1)
except InvalidXGBoostModelError:
except exceptions.InvalidXGBoostModelError:
logger.critical("Could not correctly download XGBoost model\nTry a manual download.")
sys.exit(1)
except Exception:
Expand Down
40 changes: 4 additions & 36 deletions ms2pip/result.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Definition and handling of MS²PIP results."""

from __future__ import annotations

import csv
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from psm_utils import PSM
from pydantic import ConfigDict, BaseModel
from pydantic import BaseModel, ConfigDict

try:
import spectrum_utils.plot as sup
Expand Down Expand Up @@ -115,44 +116,11 @@ def calculate_correlations(results: List[ProcessingResult]) -> None:
result.correlation = np.corrcoef(pred_int, obs_int)[0][1]


def results_to_csv(results: List["ProcessingResult"], output_file: str) -> None:
"""Write processing results to CSV file."""
with open(output_file, "wt") as f:
fieldnames = [
"psm_index",
"ion_type",
"ion_number",
"mz",
"predicted",
"observed",
]
writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()
for result in results:
if result.theoretical_mz is not None:
for ion_type in result.theoretical_mz:
for i in range(len(result.theoretical_mz[ion_type])):
writer.writerow(
{
"psm_index": result.psm_index,
"ion_type": ion_type,
"ion_number": i + 1,
"mz": "{:.6g}".format(result.theoretical_mz[ion_type][i]),
"predicted": "{:.6g}".format(
result.predicted_intensity[ion_type][i]
) if result.predicted_intensity else None,
"observed": "{:.6g}".format(result.observed_intensity[ion_type][i])
if result.observed_intensity
else None,
}
)


def correlations_to_csv(results: List["ProcessingResult"], output_file: str) -> None:
def write_correlations(results: List["ProcessingResult"], output_file: str) -> None:
"""Write correlations to CSV file."""
with open(output_file, "wt") as f:
fieldnames = ["psm_index", "correlation"]
writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter="\t", lineterminator="\n")
writer.writeheader()
for result in results:
writer.writerow({"psm_index": result.psm_index, "correlation": result.correlation})
34 changes: 19 additions & 15 deletions ms2pip/spectrum_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import csv
import itertools
import logging
import re
import warnings
from abc import ABC, abstractmethod
Expand All @@ -57,11 +58,13 @@
from ms2pip._utils import dlib
from ms2pip.result import ProcessingResult

LOGGER = logging.getLogger(__name__)


def write_spectra(
filename: Union[str, Path],
processing_results: List[ProcessingResult],
file_format: str,
file_format: str = "tsv",
write_mode: str = "w",
):
"""
Expand All @@ -80,13 +83,14 @@ def write_spectra(
"""
with SUPPORTED_FORMATS[file_format](filename, write_mode) as writer:
LOGGER.info(f"Writing to {writer.filename}")
writer.write(processing_results)


class _Writer(ABC):
"""Abstract base class for writing spectrum files."""

suffix = ".txt"
suffix = ""

def __init__(self, filename: Union[str, Path], write_mode: str = "w"):
self.filename = Path(filename).with_suffix(self.suffix)
Expand Down Expand Up @@ -182,11 +186,11 @@ def _write_row(result: ProcessingResult, ion_type: str, ion_index: int):
"psm_index": result.psm_index,
"ion_type": ion_type,
"ion_number": ion_index + 1,
"mz": "{:.10g}".format(result.theoretical_mz[ion_type][ion_index]),
"predicted": "{:.10g}".format(result.predicted_intensity[ion_type][ion_index])
"mz": "{:.8f}".format(result.theoretical_mz[ion_type][ion_index]),
"predicted": "{:.8f}".format(result.predicted_intensity[ion_type][ion_index])
if result.predicted_intensity
else None,
"observed": "{:.10g}".format(result.observed_intensity[ion_type][ion_index])
"observed": "{:.8f}".format(result.observed_intensity[ion_type][ion_index])
if result.observed_intensity
else None,
"rt": result.psm.retention_time if result.psm.retention_time else None,
Expand Down Expand Up @@ -219,7 +223,7 @@ def _write_result(self, result: ProcessingResult):

# Peaks
lines.extend(
f"{mz:.10g}\t{intensity:.10g}\t{annotation}/0.0" for mz, intensity, annotation in peaks
f"{mz:.8f}\t{intensity:.8f}\t{annotation}/0.0" for mz, intensity, annotation in peaks
)

# Write to file
Expand Down Expand Up @@ -259,7 +263,7 @@ def _format_single_modification(
if not mods:
return "Mods=0"
else:
return f"Mods={len(mods)}/{'/'.join(sorted(mods))}"
return f"Mods={len(mods)}/{'/'.join(mods)}"

@staticmethod
def _format_parent_mass(peptidoform: Peptidoform) -> str:
Expand Down Expand Up @@ -332,11 +336,11 @@ def _write_result(self, result: ProcessingResult):
]

# Peaks
lines.extend(f"{mz:.10g} {intensity:.10g}" for mz, intensity in peaks)
lines.extend(f"{mz:.8f} {intensity:.8f}" for mz, intensity in peaks)

# Write to file
self._file_object.writelines(line + "\n" for line in lines if line)
self._file_object.write("END IONS\n")
self._file_object.write("END IONS\n\n")


class Spectronaut(_Writer):
Expand Down Expand Up @@ -385,9 +389,9 @@ def _process_psm(psm: PSM) -> Dict[str, Any]:
"ModifiedPeptide": _peptidoform_str_without_charge(psm.peptidoform),
"StrippedPeptide": psm.peptidoform.sequence,
"PrecursorCharge": psm.get_precursor_charge(),
"PrecursorMz": f"{psm.peptidoform.theoretical_mz:.10g}",
"IonMobility": f"{psm.ion_mobility:.10g}" if psm.ion_mobility else None,
"iRT": f"{psm.retention_time:.10g}" if psm.retention_time else None,
"PrecursorMz": f"{psm.peptidoform.theoretical_mz:.8f}",
"IonMobility": f"{psm.ion_mobility:.8f}" if psm.ion_mobility else None,
"iRT": f"{psm.retention_time:.8f}" if psm.retention_time else None,
"ProteinId": "".join(psm.protein_list) if psm.protein_list else None,
}

Expand All @@ -411,8 +415,8 @@ def _yield_fragment_info(result: ProcessingResult) -> Generator[Dict[str, Any],
zip(intensities[ion_type], result.theoretical_mz[ion_type])
):
yield {
"RelativeFragmentIntensity": f"{intensity:.10g}",
"FragmentMz": f"{mz:.10g}",
"RelativeFragmentIntensity": f"{intensity:.8f}",
"FragmentMz": f"{mz:.8f}",
"FragmentType": fragment_type,
"FragmentNumber": ion_index + 1,
"FragmentCharge": fragment_charge,
Expand Down Expand Up @@ -567,7 +571,7 @@ def _write_result_to_ms2(
]

# Peaks
lines.extend(f"{mz:.10g}\t{intensity:.10g}" for mz, intensity in peaks)
lines.extend(f"{mz:.8f}\t{intensity:.8f}" for mz, intensity in peaks)

# Write to file
self._ms2_file_object.writelines(line + "\n" for line in lines)
Expand Down

0 comments on commit bc65c93

Please sign in to comment.