feat: Add prototype code for validating fragment file #1095

Open · wants to merge 12 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -29,6 +29,7 @@ venv/

# Include h5ads in tests
!cellxgene_schema_cli/tests/fixtures/h5ads/*
!cellxgene_schema_cli/tests/fixtures/atac_seq/*

# Kozareva big files
cb_annotated_object.RDS
310 changes: 310 additions & 0 deletions cellxgene_schema_cli/cellxgene_schema/atac_seq.py
@@ -0,0 +1,310 @@
import itertools
import logging
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

import anndata as ad
import dask
import dask.dataframe as ddf
import dask.distributed as dd
import pandas as pd
import pysam
from dask import delayed
from dask.delayed import Delayed

logger = logging.getLogger(__name__)

# TODO: these chromosome tables should be calculated from the fasta file?
# fasta location: https://www.gencodegenes.org/human/release_44.html; file name: GRCh38.primary_assembly.genome.fa
human_chromosome_by_length = {
"chr1": 248956422,
"chr2": 242193529,
"chr3": 198295559,
"chr4": 190214555,
"chr5": 181538259,
"chr6": 170805979,
"chr7": 159345973,
"chr8": 145138636,
"chr9": 138394717,
"chr10": 133797422,
"chr11": 135086622,
"chr12": 133275309,
"chr13": 114364328,
"chr14": 107043718,
"chr15": 101991189,
"chr16": 90338345,
"chr17": 83257441,
"chr18": 80373285,
"chr19": 58617616,
"chr20": 64444167,
"chr21": 46709983,
"chr22": 50818468,
"chrX": 156040895,
"chrY": 57227415,
"chrM": 16569,
"GL000009.2": 201709,
"GL000194.1": 191469,
"GL000195.1": 182896,
"GL000205.2": 185591,
"GL000213.1": 164239,
"GL000216.2": 176608,
"GL000218.1": 161147,
"GL000219.1": 179198,
"GL000220.1": 161802,
"GL000225.1": 211173,
"KI270442.1": 392061,
"KI270711.1": 42210,
"KI270713.1": 40745,
"KI270721.1": 100316,
"KI270726.1": 43739,
"KI270727.1": 448248,
"KI270728.1": 1872759,
"KI270731.1": 150754,
"KI270733.1": 179772,
"KI270734.1": 165050,
"KI270744.1": 168472,
"KI270750.1": 148850,
}
mouse_chromosome_by_length = {
"chr1": 195154279,
"chr2": 181755017,
"chr3": 159745316,
"chr4": 156860686,
"chr5": 151758149,
"chr6": 149588044,
"chr7": 144995196,
"chr8": 130127694,
"chr9": 124359700,
"chr10": 130530862,
"chr11": 121973369,
"chr12": 120092757,
"chr13": 120883175,
"chr14": 125139656,
"chr15": 104073951,
"chr16": 98008968,
"chr17": 95294699,
"chr18": 90720763,
"chr19": 61420004,
"chrX": 169476592,
"chrY": 91455967,
"chrM": 16299,
"GL456210.1": 169725,
"GL456211.1": 241735,
"GL456212.1": 153618,
"GL456219.1": 175968,
"GL456221.1": 206961,
"GL456239.1": 40056,
"GL456354.1": 195993,
"GL456372.1": 28664,
"GL456381.1": 25871,
"GL456385.1": 35240,
"JH584295.1": 1976,
"JH584296.1": 199368,
"JH584297.1": 205776,
"JH584298.1": 184189,
"JH584299.1": 953012,
"JH584303.1": 158099,
"JH584304.1": 114452,
}
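As the TODO above suggests, these tables could instead be derived from the reference fasta. A minimal sketch using pysam, assuming a local, indexed copy of the GENCODE fasta (the path is hypothetical):

import pysam

# opening a fasta reads its .fai index; pysam builds the index if it is missing
with pysam.FastaFile("GRCh38.primary_assembly.genome.fa") as fasta:
    chromosome_by_length = dict(zip(fasta.references, fasta.lengths))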

column_ordering = ["chromosome", "start coordinate", "stop coordinate", "barcode", "read support"]
allowed_chromosomes = list(set(itertools.chain(human_chromosome_by_length.keys(), mouse_chromosome_by_length.keys())))
allowed_chromosomes.sort()
chromosome_categories = pd.CategoricalDtype(categories=allowed_chromosomes, ordered=True)
column_types = {
"chromosome": chromosome_categories,
"start coordinate": int,
"stop coordinate": int,
"barcode": str,
"read support": int,
}
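For reference, a fragment file is a headerless, tab-separated, BED-like file whose five columns follow column_ordering above. A minimal sketch of one record and of reading such a file with these dtypes (the path and values are illustrative):

import dask.dataframe as ddf

# one fragment record, tab-separated, no header:
# chr1    10085   10339   AAACGAAAGAAACGCC-1    2
df = ddf.read_csv("fragments.tsv", sep="\t", names=column_ordering, dtype=column_types)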


def process_fragment(
fragment_file: str, anndata_file: str, generate_index: bool = False, dask_cluster_config: Optional[dict] = None
) -> bool:
"""
Validate the fragment against the anndata file and generate the index if the fragment is valid.

:param str fragment_file: The fragment file to process
:param str anndata_file: The anndata file to validate against
:param bool generate_index: Whether to generate the index for the fragment
:param dask_cluster_config: dask cluster configuration parameters passed to dask.distributed.LocalCluster
:return: True if the fragment and anndata file pass validation, otherwise False
"""
with tempfile.TemporaryDirectory() as tempdir:
# unzip the fragment. Subprocess is used here because the gzip CLI uses less memory than the Python gzip
# library at comparable speed.
unzipped_file = Path(tempdir) / Path(fragment_file).name.replace(".gz", "")
logger.info(f"Unzipping {fragment_file}")
with open(unzipped_file, "wb") as fp:
subprocess.run(["gunzip", "-c", fragment_file], stdout=fp, check=True)

_dask_cluster_config = dict(silence_logs=logging.ERROR, dashboard_address=None)
if dask_cluster_config:
_dask_cluster_config.update(dask_cluster_config)

with dd.LocalCluster(**_dask_cluster_config) as cluster, dd.Client(cluster):
# convert the fragment to a parquet file
logger.info(f"Converting {fragment_file} to parquet")
parquet_file = Path(tempdir) / Path(fragment_file).name.replace(".gz", ".parquet")
ddf.read_csv(unzipped_file, sep="\t", names=column_ordering, dtype=column_types).to_parquet(
parquet_file, partition_on=["chromosome"], compute=True
)

# remove the unzipped file
logger.debug(f"Removing {unzipped_file}")
unzipped_file.unlink()

errors = validate(parquet_file, anndata_file)
if any(errors):
logger.error("Errors found in Fragment and/or Anndata file")
logger.error(errors)
return False
else:
logger.info("Fragment and Anndata file are valid")

if generate_index:
logger.info(f"Sorting fragment and generating index for {fragment_file}")
index_fragment(fragment_file, parquet_file, tempdir)
logger.debug("cleaning up")


def validate(parquet_file: str, anndata_file: str) -> list[Optional[str]]:
jobs = [
validate_fragment_start_coordinate_greater_than_0(parquet_file),
validate_fragment_barcode_in_adata_index(parquet_file, anndata_file),
validate_fragment_stop_coordinate_within_chromosome(parquet_file, anndata_file),
validate_fragment_stop_greater_than_start_coordinate(parquet_file),
validate_fragment_read_support(parquet_file),
]
return jobs
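Each validator returns None on success or an error-message string on failure, so any(errors) in process_fragment is truthy exactly when at least one check failed. A small sketch of consuming the result (paths hypothetical):

errors = validate("fragments.parquet", "dataset.h5ad")
failures = [e for e in errors if e is not None]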


def validate_fragment_start_coordinate_greater_than_0(parquet_file: Path) -> Optional[str]:
df = ddf.read_parquet(parquet_file, columns=["start coordinate"])
series = df["start coordinate"] > 0
if not series.all().compute():
return "Start coordinate is less than 0"


def validate_fragment_barcode_in_adata_index(parquet_file: Path, anndata_file: Path) -> Optional[str]:
df = ddf.read_parquet(parquet_file, columns=["barcode"])
obs = ad.read_h5ad(anndata_file, backed="r").obs
barcode = set(df.groupby(by="barcode").count().compute().index)
if set(obs.index) != barcode:
return "Barcodes don't match anndata.obs.index"


def validate_fragment_stop_greater_than_start_coordinate(parquet_file: Path) -> Optional[str]:
df = ddf.read_parquet(parquet_file, columns=["start coordinate", "stop coordinate"])
series = df["stop coordinate"] > df["start coordinate"]
if not series.all().compute():
return "Stop coordinate not greater than Start coordinate or Start coordinate is less than 0"


def validate_fragment_stop_coordinate_within_chromosome(parquet_file: Path, anndata_file: Path) -> Optional[str]:
# check that the stop coordinate is within the length of the chromosome
chromosome_length_table = pd.DataFrame(
{"NCBITaxon:9606": human_chromosome_by_length, "NCBITaxon:10090": mouse_chromosome_by_length}
)
obs: pd.DataFrame = ad.read_h5ad(anndata_file, backed="r").obs
obs = obs[["organism_ontology_term_id"]] # only the organism_ontology_term_id is needed
unique_organism_ontology_term_id = obs["organism_ontology_term_id"].unique()
df: ddf.DataFrame = ddf.read_parquet(parquet_file, columns=["barcode", "chromosome", "stop coordinate"])
df = df.merge(obs, left_on="barcode", right_index=True)
df = df.merge(chromosome_length_table, left_on="chromosome", right_index=True)

for organism_ontology_term_id in unique_organism_ontology_term_id:
df_ = df[df["organism_ontology_term_id"] == organism_ontology_term_id]
df_ = df_["stop coordinate"] <= df_[organism_ontology_term_id]
if not df_.all().compute():
return "Stop coordinate is greater than the length of the chromosome"


def validate_fragment_read_support(parquet_file: Path) -> Optional[str]:
# check that the read support is greater than 0
df = ddf.read_parquet(parquet_file, columns=["read support"], filters=[("read support", "<=", 0)])
if len(df.compute()) != 0:
return "Read support is less than 0"


def detect_chromosomes(parquet_file: Path) -> list[str]:
logger.info("detecting chromosomes")
df = ddf.read_parquet(parquet_file, columns=["chromosome"]).drop_duplicates()
return df["chromosome"].values.compute()


def index_fragment(fragment_file: str, parquet_file: Path, tempdir: str):
# sort the fragment by chromosome, start coordinate, and stop coordinate, then compress it with bgzip
bgzip_output_file = fragment_file.replace(".gz", ".bgz")
bgzip_output_path = Path(bgzip_output_file)
bgzip_output_path.unlink(missing_ok=True)
bgzip_output_path.touch()
bgzip_write_lock = dd.Lock()  # lock to preserve write order

if not shutil.which("bgzip"): # check if bgzip cli is installed
logger.warning("bgzip is not installed, using slower pysam implementation")
write_algorithm = write_algorithm_by_callable["pysam"]
else:
write_algorithm = write_algorithm_by_callable["cli"]

chromosomes = detect_chromosomes(parquet_file)
jobs = prepare_fragment(chromosomes, parquet_file, bgzip_output_file, tempdir, bgzip_write_lock, write_algorithm)
# limit calls to dask.compute to improve performance. The number of jobs to run at once is determined by the
# step variable. If we run all the jobs in the same call to dask.compute, the local cluster hangs.
# TODO: investigate why
step = 4
# run the jobs in batches of step tasks at a time
for i in range(0, len(jobs), step):
dask.compute(jobs[i : i + step])

logger.info(f"Fragment sorted and compressed: {bgzip_output_file}")

pysam.tabix_index(bgzip_output_file, preset="bed", force=True)
tabix_output_file = bgzip_output_file + ".tbi"
logger.info(f"Index file generated: {tabix_output_file}")


@delayed
def sort_fragment(parquet_file: Path, write_path: str, chromosome: str) -> Path:
temp_data = Path(write_path) / f"temp_{chromosome}.tsv"
df = ddf.read_parquet(parquet_file, filters=[("chromosome", "==", chromosome)])
df = df[column_ordering]
df = df.sort_values(["start coordinate", "stop coordinate"])

df.to_csv(temp_data, sep="\t", index=False, header=False, mode="w", single_file=True)
return temp_data


@delayed
def write_bgzip_pysam(input_file: str, bgzip_output_file: str, write_lock: dd.Lock):
with write_lock, pysam.libcbgzf.BGZFile(bgzip_output_file, mode="ab") as f_out, open(input_file, "rb") as f_in:
while data := f_in.read(2**20):
f_out.write(data)


@delayed
def write_bgzip_cli(input_file: str, bgzip_output_file: str, write_lock: dd.Lock):
with write_lock:
subprocess.run([f"cat {input_file} | bgzip --threads=8 -c >> {bgzip_output_file}"], shell=True, check=True)


write_algorithm_by_callable = {"pysam": write_bgzip_pysam, "cli": write_bgzip_cli}


def prepare_fragment(
chromosomes: list[str],
parquet_file: Path,
bgzip_output_file: str,
tempdir: str,
write_lock: dd.Lock,
write_algorithm: callable,
) -> list[Delayed]:
jobs = []
for chromosome in chromosomes:
temp_data = sort_fragment(parquet_file, tempdir, chromosome)
jobs.append(write_algorithm(temp_data, bgzip_output_file, write_lock))
return jobs
36 changes: 29 additions & 7 deletions cellxgene_schema_cli/cellxgene_schema/cli.py
@@ -1,3 +1,4 @@
import logging
import sys

import click
@@ -13,7 +14,7 @@ def schema_cli():
pass


@click.command(
@schema_cli.command(
name="validate",
short_help="Check that an h5ad follows the cellxgene data integration schema.",
help="Check that an h5ad follows the cellxgene data integration schema. If validation fails this command will "
@@ -51,7 +52,32 @@ def schema_validate(h5ad_file, add_labels_file, ignore_labels, verbose):
sys.exit(1)


@click.command(
@schema_cli.command(
name="validate-fragment",
short_help="Check that an ATAC SEQ fragment follows the cellxgene data integration schema.",
help="Check that an ATAC SEQ fragment follows the cellxgene data integration schema. If validation fails this "
"command will return an exit status of 1 otherwise 0. When the '--generate-index' tag is present, "
"the command will generate a tabix compatible version of the fragment and tabix index. The generated "
"fragment will have the file suffix .bgz and the index will have the file suffix .bgz.tbi.",
)
@click.argument("h5ad_file", nargs=1, type=click.Path(exists=True, dir_okay=False))
@click.argument("fragment_file", nargs=1, type=click.Path(exists=True, dir_okay=False))
@click.option("-i", "--generate-index", help="Generate index for fragment", is_flag=True)
@click.option("-v", "--verbose", help="When present will set logging level to debug", is_flag=True)
def fragment_validate(h5ad_file, fragment_file, generate_index, verbose):
from .atac_seq import process_fragment

logging.basicConfig(level=logging.ERROR)
if verbose:
logging.getLogger("cellxgene_schema").setLevel(logging.DEBUG)
else:
logging.getLogger("cellxgene_schema").setLevel(logging.INFO)

if not process_fragment(fragment_file, h5ad_file, generate_index=generate_index):
sys.exit(1)
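For example, with the package's cellxgene-schema entry point (file names hypothetical):

cellxgene-schema validate-fragment dataset.h5ad fragments.tsv.gz --generate-index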


@schema_cli.command(
name="remove-labels",
short_help="Create a copy of an h5ad without portal-added labels",
help="Create a copy of an h5ad without portal-added labels.",
@@ -78,7 +104,7 @@ def remove_labels(input_file, output_file):
anndata_label_remover.adata.write(output_file, compression="gzip")


@click.command(
@schema_cli.command(
name="migrate",
short_help="Convert an h5ad to the latest schema version.",
help="Convert an h5ad from the previous to latest minor schema version. No validation will be "
@@ -94,9 +120,5 @@ def migrate(input_file, output_file, collection_id, dataset_id):
migrate(input_file, output_file, collection_id, dataset_id)


schema_cli.add_command(schema_validate)
schema_cli.add_command(migrate)
schema_cli.add_command(remove_labels)

if __name__ == "__main__":
schema_cli()
3 changes: 3 additions & 0 deletions cellxgene_schema_cli/requirements.txt
@@ -9,3 +9,6 @@ scipy<2
semver<4
xxhash<4
matplotlib<4
pysam
dask[array,distributed]<2025
dask-expr