diff --git a/src/baskerville/dataset.py b/src/baskerville/dataset.py
index c8360c3..20da521 100644
--- a/src/baskerville/dataset.py
+++ b/src/baskerville/dataset.py
@@ -61,6 +61,7 @@ def __init__(
         mode: str = "eval",
         tfr_pattern: str = None,
         targets_slice_file: str = None,
+        reduce_mean: bool = False,
     ):
         self.data_dir = data_dir
         self.split_label = split_label
@@ -69,6 +70,7 @@ def __init__(
         self.seq_length_crop = seq_length_crop
         self.mode = mode
         self.tfr_pattern = tfr_pattern
+        self.reduce_mean = reduce_mean
 
         # read data parameters
         data_stats_file = "%s/statistics.json" % self.data_dir
@@ -147,11 +149,13 @@ def parse_proto(example_protos):
             targets = tf.cast(targets, tf.float32)
             if self.targets_slice is not None:
                 targets = targets[:, self.targets_slice]
+            if self.reduce_mean:
+                targets = target_reduce_mean(targets)
 
             return sequence, targets
 
         return parse_proto
-    
+
 
     def make_dataset(self, cycle_length=4):
         """Make tf.data.Dataset w/ transformations."""
@@ -423,3 +427,24 @@ def untransform_preds1(preds, targets_df, unscale=False, unclip=True):
         preds = preds * scale
 
     return preds
+
+def target_reduce_mean(tensor):
+    """
+    This function computes the row-wise mean of a 2D tensor,
+    subtracts the mean from each element in the respective row,
+    and appends the mean as an additional column.
+
+    Args:
+        tensor (tf.Tensor): Input tensor of shape (target_length, num_targets)
+
+    Returns:
+        tf.Tensor: Output tensor of shape (target_length, num_targets+1),
+            where the last column is the row-wise mean.
+    """
+
+    mean_tensor = tf.reduce_mean(tensor, axis=1)  # (target_length,)
+    mean_tensor_expanded = tf.reshape(mean_tensor, (-1, 1))  # (target_length, 1)
+    difference_tensor = tensor - mean_tensor_expanded  # (target_length, num_targets)
+    result_tensor = tf.concat([difference_tensor, mean_tensor_expanded], axis=1)
+
+    return result_tensor
diff --git a/src/baskerville/scripts/hound_eval.py b/src/baskerville/scripts/hound_eval.py
index 2115e17..1f13bee 100755
--- a/src/baskerville/scripts/hound_eval.py
+++ b/src/baskerville/scripts/hound_eval.py
@@ -37,6 +37,16 @@
 
 """
 
+# fold the appended mean task back into the per-target predictions
+def merge_mean_task(model_0):
+    outputs = model_0.layers[-1].output
+    task_other = outputs[:, :, :-1]
+    task_mean = tf.expand_dims(outputs[:, :, -1], axis=2)
+    new_task = task_other + task_mean
+    new_outputs = tf.keras.layers.ReLU()(new_task)
+    new_model = tf.keras.Model(inputs=model_0.layers[0].input, outputs=new_outputs)
+    return new_model
+
 def main():
     parser = argparse.ArgumentParser(description="Evaluate a trained model.")
     parser.add_argument(
@@ -133,9 +143,24 @@ def main():
     params_model = params["model"]
     params_train = params["train"]
 
+    # add the mean task when the model was trained with transfer reduce_mean
+    params_transfer = params.get("transfer", {})
+    transfer_reduce_mean = params_transfer.get("reduce_mean", False)
+    if transfer_reduce_mean:
+        params_model["head_human"]["units"] += 1
+
     # set strand pairs
     if "strand_pair" in targets_df.columns:
-        params_model["strand_pair"] = [np.array(targets_df.strand_pair)]
+        tmp = np.array(targets_df.strand_pair)
+
+        # add additional targets when reduce_mean
+        if transfer_reduce_mean:
+            if all(tmp == targets_df.index):  # unstranded
+                tmp = np.append(tmp, [len(tmp)])
+            else:  # stranded
+                tmp = np.append(tmp, [len(tmp) + 1, len(tmp)])
+
+        params_model["strand_pair"] = [tmp]
 
     # construct eval data
     eval_data = dataset.SeqDataset(
@@ -160,7 +185,9 @@
 
     seqnn_model.restore(args.model_file, args.head_i)
     seqnn_model.build_ensemble(args.rc, args.shifts)
-    
+    if transfer_reduce_mean:
+        seqnn_model.ensemble = merge_mean_task(seqnn_model.ensemble)
+
     #######################################################
     # evaluate
     loss_label = params_train.get("loss", "poisson").lower()
@@ -168,6 +195,8 @@
     loss_fn = trainer.parse_loss(loss_label, spec_weight=spec_weight)
 
     # evaluate
+    # note: with reduce_mean, the merged ensemble emits the original target count
+
     test_loss, test_metric1, test_metric2 = seqnn_model.evaluate(
         eval_data, loss_label=loss_label, loss_fn=loss_fn
     )
diff --git a/src/baskerville/scripts/hound_transfer.py b/src/baskerville/scripts/hound_transfer.py
index 88ca008..6c38fc2 100755
--- a/src/baskerville/scripts/hound_transfer.py
+++ b/src/baskerville/scripts/hound_transfer.py
@@ -121,7 +121,8 @@ def main():
     transfer_conv_rank = params_transfer.get("conv_latent", 4)
     transfer_lora_alpha = params_transfer.get("lora_alpha", 16)
     transfer_locon_alpha = params_transfer.get("locon_alpha", 1)
-    
+    transfer_reduce_mean = params_transfer.get("reduce_mean", False)
+
     if transfer_mode not in ["full", "linear", "adapter"]:
         raise ValueError("transfer mode must be one of full, linear, adapter")
 
@@ -133,8 +134,18 @@ def main():
     for data_dir in args.data_dirs:
         # set strand pairs
         targets_df = pd.read_csv("%s/targets.txt" % data_dir, sep="\t", index_col=0)
+
         if "strand_pair" in targets_df.columns:
-            strand_pairs.append(np.array(targets_df.strand_pair))
+            tmp = np.array(targets_df.strand_pair)
+            # add additional targets when reduce_mean
+            if transfer_reduce_mean:
+                if all(tmp == targets_df.index):  # unstranded
+                    tmp = np.append(tmp, [len(tmp)])
+                else:  # stranded
+                    tmp = np.append(tmp, [len(tmp) + 1, len(tmp)])
+
+            strand_pairs.append(tmp)
+            # note: strand_pairs now includes the mean-task index when enabled
 
         # load train data
         train_data.append(
@@ -145,6 +156,7 @@
                 shuffle_buffer=params_train.get("shuffle_buffer", 128),
                 mode="train",
                 tfr_pattern=args.tfr_train,
+                reduce_mean=transfer_reduce_mean,
             )
         )
 
@@ -156,6 +168,7 @@
                 batch_size=params_train["batch_size"],
                 mode="eval",
                 tfr_pattern=args.tfr_eval,
+                reduce_mean=transfer_reduce_mean,
            )
         )
 
@@ -170,6 +183,10 @@
 
     # initialize model
     params_model["verbose"] = False
+
+    if transfer_reduce_mean:
+        params_model["head_human"]["units"] += 1
+
     seqnn_model = seqnn.SeqNN(params_model)
 
     # restore