Skip to content

Commit

Permalink
improve calibration speed
Browse files Browse the repository at this point in the history
  • Loading branch information
rodvrees committed Apr 15, 2024
1 parent d7bb1b6 commit da122ac
Showing 1 changed file with 20 additions and 49 deletions.
69 changes: 20 additions & 49 deletions im2deep/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,13 @@ def get_ccs_shift(
"""
LOGGER.debug(f"Using charge state {use_charge_state} for CCS shift calculation.")

tmp_df = cal_df.copy(deep=True)
tmp_ref_df = reference_dataset.copy(deep=True)

tmp_df["sequence"] = tmp_df["peptidoform"].apply(lambda x: x.proforma.split("\\")[0])
tmp_df["charge"] = tmp_df["peptidoform"].apply(lambda x: x.precursor_charge)
tmp_ref_df["sequence"] = tmp_ref_df["peptidoform"].apply(
lambda x: Peptidoform(x).proforma.split("\\")[0]
)
tmp_ref_df["charge"] = tmp_ref_df["peptidoform"].apply(
lambda x: Peptidoform(x).precursor_charge
)

reference_tmp = tmp_ref_df[tmp_ref_df["charge"] == use_charge_state]
df_tmp = tmp_df[tmp_df["charge"] == use_charge_state]
reference_tmp = reference_dataset[reference_dataset["charge"] == use_charge_state]
df_tmp = cal_df[cal_df["charge"] == use_charge_state]
both = pd.merge(
left=reference_tmp,
right=df_tmp,
right_on=["sequence", "charge"],
left_on=["sequence", "charge"],
left_on=["peptidoform", "charge"],
how="inner",
suffixes=("_ref", "_data"),
)
Expand All @@ -90,7 +78,7 @@ def get_ccs_shift(

# Mean amount by which the CCS in the calibration data exceeds the reference CCS;
# predictions need to be increased by this amount
return 0 if both.shape[0] == 0 else np.mean(both["ccs_observed"] - both["CCS"])
return 0 if both.empty else np.mean(both["ccs_observed"] - both["CCS"])


def get_ccs_shift_per_charge(cal_df: pd.DataFrame, reference_dataset: pd.DataFrame) -> ndarray:
Expand All @@ -111,25 +99,11 @@ def get_ccs_shift_per_charge(cal_df: pd.DataFrame, reference_dataset: pd.DataFra
CCS shift factors per charge state.
"""
tmp_df = cal_df.copy(deep=True)
tmp_ref_df = reference_dataset.copy(deep=True)

tmp_df["sequence"] = tmp_df["peptidoform"].apply(lambda x: x.proforma.split("\\")[0])
tmp_df["charge"] = tmp_df["peptidoform"].apply(lambda x: x.precursor_charge)
tmp_ref_df["sequence"] = tmp_ref_df["peptidoform"].apply(
lambda x: Peptidoform(x).proforma.split("\\")[0]
)
tmp_ref_df["charge"] = tmp_ref_df["peptidoform"].apply(
lambda x: Peptidoform(x).precursor_charge
)

reference_tmp = tmp_ref_df
df_tmp = tmp_df
both = pd.merge(
left=reference_tmp,
right=df_tmp,
left=reference_dataset,
right=cal_df,
right_on=["sequence", "charge"],
left_on=["sequence", "charge"],
left_on=["peptidoform", "charge"],
how="inner",
suffixes=("_ref", "_data"),
)
Expand Down Expand Up @@ -159,7 +133,6 @@ def calculate_ccs_shift(
CCS shift factor.
"""
cal_df["charge"] = cal_df["peptidoform"].apply(lambda x: x.precursor_charge)
cal_df = cal_df[cal_df["charge"] < 7] # predictions do not go higher for IM2Deep

if not per_charge:
Expand Down Expand Up @@ -207,37 +180,35 @@ def linear_calibration(
"""
LOGGER.info("Calibrating CCS values using linear calibration...")
calibration_dataset['sequence'] = calibration_dataset['peptidoform'].apply(lambda x: x.proforma.split("\\")[0])
calibration_dataset['charge'] = calibration_dataset['peptidoform'].apply(lambda x: x.precursor_charge)
# reference_dataset['sequence'] = reference_dataset['peptidoform'].apply(lambda x: x.split('/')[0])
reference_dataset['charge'] = reference_dataset['peptidoform'].apply(lambda x: int(x.split('/')[1]))

if per_charge:
LOGGER.info('Getting general shift factor')
general_shift = calculate_ccs_shift(
calibration_dataset,
reference_dataset,
per_charge=False,
use_charge_state=use_charge_state,
)
LOGGER.info('Getting shift factors per charge state')
shift_factor_dict = calculate_ccs_shift(
calibration_dataset, reference_dataset, per_charge=True
)
for charge in preds_df["charge"].unique():
if charge not in shift_factor_dict:
LOGGER.info(
"No overlapping precursors for charge state {}. Using overall shift factor for precursors with that charge.".format(
charge
)
)
shift_factor_dict[charge] = general_shift
LOGGER.info("Shift factors per charge: {}".format(shift_factor_dict))
preds_df["predicted_ccs"] = preds_df.apply(
lambda x: x["predicted_ccs"] + shift_factor_dict[x["charge"]], axis=1
)

preds_df['shift'] = preds_df['charge'].map(shift_factor_dict).fillna(general_shift)
preds_df['predicted_ccs'] = preds_df['predicted_ccs'] + preds_df['shift']

else:
shift_factor = calculate_ccs_shift(
calibration_dataset,
reference_dataset,
per_charge=False,
use_charge_state=use_charge_state,
)
preds_df["predicted_ccs"] = preds_df.apply(
lambda x: x["predicted_ccs"] + shift_factor, axis=1
)
preds_df['predicted_ccs'] += shift_factor

LOGGER.info("CCS values calibrated.")
return preds_df

0 comments on commit da122ac

Please sign in to comment.