Skip to content

Commit

Permalink
reduce-mean
Browse files Browse the repository at this point in the history
  • Loading branch information
hy395 committed Oct 23, 2024
1 parent 9559c76 commit 7dd4839
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 5 deletions.
27 changes: 26 additions & 1 deletion src/baskerville/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
mode: str = "eval",
tfr_pattern: str = None,
targets_slice_file: str = None,
reduce_mean = False,
):
self.data_dir = data_dir
self.split_label = split_label
Expand All @@ -69,6 +70,7 @@ def __init__(
self.seq_length_crop = seq_length_crop
self.mode = mode
self.tfr_pattern = tfr_pattern
self.reduce_mean = reduce_mean

# read data parameters
data_stats_file = "%s/statistics.json" % self.data_dir
Expand Down Expand Up @@ -147,11 +149,13 @@ def parse_proto(example_protos):
targets = tf.cast(targets, tf.float32)
if self.targets_slice is not None:
targets = targets[:, self.targets_slice]
if self.reduce_mean:
targets = target_reduce_mean(targets)

return sequence, targets

return parse_proto

def make_dataset(self, cycle_length=4):
"""Make tf.data.Dataset w/ transformations."""

Expand Down Expand Up @@ -423,3 +427,24 @@ def untransform_preds1(preds, targets_df, unscale=False, unclip=True):
preds = preds * scale

return preds

def target_reduce_mean(tensor):
    """Append the row-wise mean of targets as an extra column.

    Computes the mean across targets at each position, subtracts it from
    every target, and concatenates the mean as a final column — so a model
    can learn the shared mean and the per-target deviations separately.

    Args:
        tensor (tf.Tensor): Input tensor of shape (target_length, num_targets).

    Returns:
        tf.Tensor: Tensor of shape (target_length, num_targets + 1) whose
            first num_targets columns are the mean-subtracted targets and
            whose last column is the row-wise mean.
    """
    # keepdims=True yields shape (target_length, 1) directly, so the
    # subtraction broadcasts without a separate tf.reshape step.
    row_mean = tf.reduce_mean(tensor, axis=1, keepdims=True)
    deviations = tensor - row_mean  # (target_length, num_targets)
    return tf.concat([deviations, row_mean], axis=1)
33 changes: 31 additions & 2 deletions src/baskerville/scripts/hound_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@
"""


# merge the appended mean-target channel back into the per-task outputs
# (the previous comment, about inserting adapters, described a different
# function and was misleading)
def merge_mean_task(model_0):
    """Fold the model's last output channel (the predicted mean) back into
    every other channel, undoing the reduce-mean target transform.

    The last channel of the model's output is added to each of the other
    channels, then a ReLU is applied, producing predictions on the original
    (untransformed) target scale.

    Args:
        model_0 (tf.keras.Model): Model whose output's final axis holds
            num_targets deviation channels followed by one mean channel.

    Returns:
        tf.keras.Model: New model emitting num_targets output channels.
    """
    outputs = model_0.layers[-1].output
    task_other = outputs[:, :, :-1]  # per-task deviation channels
    # expand_dims keeps rank 3 so the addition broadcasts across tasks
    task_mean = tf.expand_dims(outputs[:, :, -1], axis=2)
    # NOTE(review): ReLU clamps negative sums to zero — presumably to match
    # the non-negative activation of the original output head; confirm.
    merged = tf.keras.layers.ReLU()(task_other + task_mean)
    return tf.keras.Model(inputs=model_0.layers[0].input, outputs=merged)

def main():
parser = argparse.ArgumentParser(description="Evaluate a trained model.")
parser.add_argument(
Expand Down Expand Up @@ -133,9 +143,24 @@ def main():
params_model = params["model"]
params_train = params["train"]

# add task if reduce mean = True
params_transfer = params["transfer"]
transfer_reduce_mean = params_transfer.get("reduce_mean", False)
if transfer_reduce_mean:
params_model['head_human']['units'] += 1

# set strand pairs
if "strand_pair" in targets_df.columns:
params_model["strand_pair"] = [np.array(targets_df.strand_pair)]
tmp = np.array(targets_df.strand_pair)

# add additional targets when reduce_mean
if transfer_reduce_mean:
if all(tmp == targets_df.index): # unstranded
tmp = np.append(tmp, len(tmp))
else: # stranded
tmp = np.append(tmp, [len(tmp)+1, len(tmp)])

params_model["strand_pair"] = [tmp]

# construct eval data
eval_data = dataset.SeqDataset(
Expand All @@ -160,14 +185,18 @@ def main():
seqnn_model.restore(args.model_file, args.head_i)

seqnn_model.build_ensemble(args.rc, args.shifts)

if transfer_reduce_mean:
seqnn_model.ensemble = merge_mean_task(seqnn_model.ensemble)

#######################################################
# evaluate
loss_label = params_train.get("loss", "poisson").lower()
spec_weight = params_train.get("spec_weight", 1)
loss_fn = trainer.parse_loss(loss_label, spec_weight=spec_weight)

# evaluate
print(seqnn_model.ensemble.output_shape[-1])

test_loss, test_metric1, test_metric2 = seqnn_model.evaluate(
eval_data, loss_label=loss_label, loss_fn=loss_fn
)
Expand Down
21 changes: 19 additions & 2 deletions src/baskerville/scripts/hound_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def main():
transfer_conv_rank = params_transfer.get("conv_latent", 4)
transfer_lora_alpha = params_transfer.get("lora_alpha", 16)
transfer_locon_alpha = params_transfer.get("locon_alpha", 1)

transfer_reduce_mean = params_transfer.get("reduce_mean", False)

if transfer_mode not in ["full", "linear", "adapter"]:
raise ValueError("transfer mode must be one of full, linear, adapter")

Expand All @@ -133,8 +134,18 @@ def main():
for data_dir in args.data_dirs:
# set strand pairs
targets_df = pd.read_csv("%s/targets.txt" % data_dir, sep="\t", index_col=0)

if "strand_pair" in targets_df.columns:
strand_pairs.append(np.array(targets_df.strand_pair))
tmp = np.array(targets_df.strand_pair)
# add additional targets when reduce_mean
if transfer_reduce_mean:
if all(tmp == targets_df.index): # unstranded
tmp = np.append(tmp, [len(tmp)])
else: # stranded
tmp = np.append(tmp, [len(tmp)+1, len(tmp)])

strand_pairs.append(tmp)
print(strand_pairs)

# load train data
train_data.append(
Expand All @@ -145,6 +156,7 @@ def main():
shuffle_buffer=params_train.get("shuffle_buffer", 128),
mode="train",
tfr_pattern=args.tfr_train,
reduce_mean=transfer_reduce_mean
)
)

Expand All @@ -156,6 +168,7 @@ def main():
batch_size=params_train["batch_size"],
mode="eval",
tfr_pattern=args.tfr_eval,
reduce_mean=transfer_reduce_mean
)
)

Expand All @@ -170,6 +183,10 @@ def main():

# initialize model
params_model["verbose"] = False

if transfer_reduce_mean:
params_model['head_human']['units'] += 1

seqnn_model = seqnn.SeqNN(params_model)

# restore
Expand Down

0 comments on commit 7dd4839

Please sign in to comment.