Skip to content

Commit

Permalink
Added option to ignore soft-clip when undoing transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Linder committed May 3, 2024
1 parent 13623eb commit a3a4320
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/baskerville/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def targets_prep_strand(targets_df):
return targets_strand_df


def untransform_preds(preds, targets_df, unscale=False):
def untransform_preds(preds, targets_df, unscale=False, unclip=True):
"""Undo the squashing transformations performed for the tasks.
Args:
Expand All @@ -377,9 +377,10 @@ def untransform_preds(preds, targets_df, unscale=False):
preds (np.array): Untransformed predictions LxT.
"""
# clip soft
cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0)
preds_unclip = cs - 1 + (preds - cs + 1) ** 2
preds = np.where(preds > cs, preds_unclip, preds)
if unclip :
cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0)
preds_unclip = cs - 1 + (preds - cs + 1) ** 2
preds = np.where(preds > cs, preds_unclip, preds)

# sqrt
sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat])
Expand All @@ -393,7 +394,7 @@ def untransform_preds(preds, targets_df, unscale=False):
return preds


def untransform_preds1(preds, targets_df, unscale=False):
def untransform_preds1(preds, targets_df, unscale=False, unclip=True):
"""Undo the squashing transformations performed for the tasks.
Args:
Expand All @@ -408,9 +409,10 @@ def untransform_preds1(preds, targets_df, unscale=False):
preds = preds / scale

# clip soft
cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0)
preds_unclip = cs + (preds - cs) ** 2
preds = np.where(preds > cs, preds_unclip, preds)
if unclip :
cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0)
preds_unclip = cs + (preds - cs) ** 2
preds = np.where(preds > cs, preds_unclip, preds)

# ** 0.75
sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat])
Expand Down

0 comments on commit a3a4320

Please sign in to comment.