Ran Python linting (black)
kbestak committed Sep 27, 2023
1 parent ba4ab94 commit f47b8d1
Showing 3 changed files with 103 additions and 102 deletions.
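
For context, black rewrites Python source to one uniform style without changing behavior: single quotes become double quotes, `=` and binary operators get surrounding spaces, comments get a space after `#`, top-level definitions get two blank lines between them, and calls longer than the line limit are split across lines with a trailing comma. A minimal before/after sketch of the kinds of edits in this commit:

# Before black
description="""Easy-to-use, large scale CLAHE"""
tile = dtype_convert(tile, 'uint16')

# After black
description = """Easy-to-use, large scale CLAHE"""
tile = dtype_convert(tile, "uint16")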
80 changes: 43 additions & 37 deletions bin/apply_clahe.dask.py
@@ -22,7 +22,8 @@
 from os.path import abspath
 from argparse import ArgumentParser as AP
 import time
-#from memory_profiler import profile
+
+# from memory_profiler import profile
 # This API is apparently changing in skimage 1.0 but it's not clear to
 # me what the replacement will be, if any. We'll explicitly import
 # this so it will break loudly if someone tries this with skimage 1.0.
@@ -34,7 +35,7 @@

 def get_args():
     # Script description
-    description="""Easy-to-use, large scale CLAHE"""
+    description = """Easy-to-use, large scale CLAHE"""

     # Add parser
     parser = AP(description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
@@ -43,10 +44,16 @@ def get_args():
     inputs = parser.add_argument_group(title="Required Input", description="Path to required input file")
     inputs.add_argument("-r", "--raw", dest="raw", action="store", required=True, help="File path to input image.")
     inputs.add_argument("-o", "--output", dest="output", action="store", required=True, help="Path to output image.")
-    inputs.add_argument("-c", "--channel", dest="channel", action="store", required=True, help="Channel on which CLAHE will be applied")
+    inputs.add_argument(
+        "-c", "--channel", dest="channel", action="store", required=True, help="Channel on which CLAHE will be applied"
+    )
     inputs.add_argument("-l", "--cliplimit", dest="clip", action="store", required=True, help="Clip Limit for CLAHE")
-    inputs.add_argument("--kernel", dest="kernel", action="store", required=False, default=None, help="Kernel size for CLAHE")
-    inputs.add_argument("-g", "--nbins", dest="nbins", action="store", required=False, default=256, help="Number of bins for CLAHE")
+    inputs.add_argument(
+        "--kernel", dest="kernel", action="store", required=False, default=None, help="Kernel size for CLAHE"
+    )
+    inputs.add_argument(
+        "-g", "--nbins", dest="nbins", action="store", required=False, default=256, help="Number of bins for CLAHE"
+    )
     inputs.add_argument("-p", "--pixel-size", dest="pixel_size", action="store", required=True, help="Image pixel size")

     arg = parser.parse_args()
@@ -61,54 +68,58 @@ def get_args():

     return arg

+
 def preduce(coords, img_in, img_out):
     print(img_in.dtype)
     (iy1, ix1), (iy2, ix2) = coords
     (oy1, ox1), (oy2, ox2) = np.array(coords) // 2
     tile = skimage.img_as_float32(img_in[iy1:iy2, ix1:ix2])
     tile = skimage.transform.downscale_local_mean(tile, (2, 2))
-    tile = dtype_convert(tile, 'uint16')
-    #tile = dtype_convert(tile, img_in.dtype)
+    tile = dtype_convert(tile, "uint16")
+    # tile = dtype_convert(tile, img_in.dtype)
     img_out[oy1:oy2, ox1:ox2] = tile

+
 def format_shape(shape):
     return "%dx%d" % (shape[1], shape[0])

+
 def subres_tiles(level, level_full_shapes, tile_shapes, outpath, scale):
     print(f"\n processing level {level}")
     assert level >= 1
     num_channels, h, w = level_full_shapes[level]
     tshape = tile_shapes[level] or (h, w)
     tiff = tifffile.TiffFile(outpath)
-    zimg = zarr.open(tiff.aszarr(series=0, level=level-1, squeeze=False))
+    zimg = zarr.open(tiff.aszarr(series=0, level=level - 1, squeeze=False))
     for c in range(num_channels):
-        sys.stdout.write(
-            f"\r processing channel {c + 1}/{num_channels}"
-        )
+        sys.stdout.write(f"\r processing channel {c + 1}/{num_channels}")
         sys.stdout.flush()
         th = tshape[0] * scale
         tw = tshape[1] * scale
         for y in range(0, zimg.shape[1], th):
             for x in range(0, zimg.shape[2], tw):
-                a = zimg[c, y:y+th, x:x+tw, 0]
-                a = skimage.transform.downscale_local_mean(
-                    a, (scale, scale)
-                )
+                a = zimg[c, y : y + th, x : x + tw, 0]
+                a = skimage.transform.downscale_local_mean(a, (scale, scale))
                 if np.issubdtype(zimg.dtype, np.integer):
                     a = np.around(a)
-                a = a.astype('uint16')
+                a = a.astype("uint16")
                 yield a

+
 def main(args):
     print(f"Head directory = {args.raw}")
-    print(f"Channel = {args.channel}, ClipLimit = {args.clip}, nbins = {args.nbins}, kernel_size = {args.kernel}, pixel_size = {args.pixel_size}")
+    print(
+        f"Channel = {args.channel}, ClipLimit = {args.clip}, nbins = {args.nbins}, kernel_size = {args.kernel}, pixel_size = {args.pixel_size}"
+    )

-    #clahe = cv2.createCLAHE(clipLimit = int(args.clip), tileGridSize=tuple(map(int, args.grid)))
+    # clahe = cv2.createCLAHE(clipLimit = int(args.clip), tileGridSize=tuple(map(int, args.grid)))

     img_raw = AI.AICSImage(args.raw)
-    img_dask = img_raw.get_image_dask_data("CYX").astype('uint16')
-    adapted = img_dask[args.channel].compute()/65535
-    adapted = (equalize_adapthist(adapted, kernel_size=args.kernel, clip_limit=args.clip, nbins=args.nbins)*65535).astype('uint16')
+    img_dask = img_raw.get_image_dask_data("CYX").astype("uint16")
+    adapted = img_dask[args.channel].compute() / 65535
+    adapted = (
+        equalize_adapthist(adapted, kernel_size=args.kernel, clip_limit=args.clip, nbins=args.nbins) * 65535
+    ).astype("uint16")
     img_dask[args.channel] = adapted

     # construct levels
@@ -121,7 +132,7 @@ def main(args):
     num_channels = img_dask.shape[0]
     num_levels = (np.ceil(np.log2(max(base_shape) / tile_size)) + 1).astype(int)
     factors = 2 ** np.arange(num_levels)
-    shapes = (np.ceil(np.array(base_shape) / factors[:,None])).astype(int)
+    shapes = (np.ceil(np.array(base_shape) / factors[:, None])).astype(int)

     print("Pyramid level sizes: ")
     for i, shape in enumerate(shapes):
@@ -137,40 +148,35 @@ def main(args):
         level_full_shapes.append((num_channels, shape[0], shape[1]))
     level_shapes = shapes
     tip_level = np.argmax(np.all(level_shapes < tile_size, axis=1))
-    tile_shapes = [
-        (tile_size, tile_size) if i <= tip_level else None
-        for i in range(len(level_shapes))
-    ]
+    tile_shapes = [(tile_size, tile_size) if i <= tip_level else None for i in range(len(level_shapes))]

     # write pyramid
     with tifffile.TiffWriter(args.output, ome=True, bigtiff=True) as tiff:
         tiff.write(
-            data = img_dask,
-            shape = level_full_shapes[0],
-            subifds=int(num_levels-1),
+            data=img_dask,
+            shape=level_full_shapes[0],
+            subifds=int(num_levels - 1),
             dtype=dtype,
             resolution=(10000 / pixel_size, 10000 / pixel_size, "centimeter"),
-            tile=tile_shapes[0]
+            tile=tile_shapes[0],
         )
-        for level, (shape, tile_shape) in enumerate(
-            zip(level_full_shapes[1:], tile_shapes[1:]), 1
-        ):
+        for level, (shape, tile_shape) in enumerate(zip(level_full_shapes[1:], tile_shapes[1:]), 1):
             tiff.write(
-                data = subres_tiles(level, level_full_shapes, tile_shapes, args.output, scale),
+                data=subres_tiles(level, level_full_shapes, tile_shapes, args.output, scale),
                 shape=shape,
                 subfiletype=1,
                 dtype=dtype,
-                tile=tile_shape
+                tile=tile_shape,
             )

     # note about metadata: the channels, planes etc were adjusted not to include the removed channels, however
     # the channel ids have stayed the same as before removal. E.g if channels 1 and 2 are removed,
     # the channel ids in the metadata will skip indices 1 and 2 (channel_id:0, channel_id:3, channel_id:4 ...)
-    #tifffile.tiffcomment(args.output, to_xml(metadata))
+    # tifffile.tiffcomment(args.output, to_xml(metadata))
     print()


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Read in arguments
     args = get_args()

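The functional core of apply_clahe.dask.py is untouched by this commit: the selected channel is computed out of the dask array, rescaled to [0, 1] (the range equalize_adapthist expects), equalized, and scaled back to 16-bit. A minimal self-contained sketch of that step, assuming the usual skimage.exposure import; the random image and parameter values are placeholders, not values from this pipeline:

import numpy as np
from skimage.exposure import equalize_adapthist

channel = np.random.randint(0, 65536, size=(512, 512), dtype=np.uint16)  # stand-in for one image channel
scaled = channel / 65535  # rescale to [0, 1] as equalize_adapthist requires
clahe = equalize_adapthist(scaled, kernel_size=None, clip_limit=0.01, nbins=256)
channel_out = (clahe * 65535).astype("uint16")  # back to uint16 for the OME-TIFF pyramid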
81 changes: 43 additions & 38 deletions bin/collect_QC.py
@@ -6,79 +6,84 @@
 import argparse
 import pandas as pd

+
 def summarize_spots(spot_table):
     ## Calculate number of spots per gene
-    tx_per_gene = spot_table.groupby('gene').count().reset_index()
+    tx_per_gene = spot_table.groupby("gene").count().reset_index()

     ## Calculate the total number of spots in spot_table
     total_spots = spot_table.shape[0]

-    return(tx_per_gene,total_spots)
+    return (tx_per_gene, total_spots)

-def summarize_segmasks(mcquant,spots_summary):
+
+def summarize_segmasks(mcquant, spots_summary):
     ## Calculate the total number of cells (rows) in mcquant
     total_cells = mcquant.shape[0]

     ## Calculate the average segmentation area from column Area in mcquant
-    avg_area = mcquant['Area'].mean()
+    avg_area = mcquant["Area"].mean()

     ## Calculate the % of spots assigned
     ## Subset mcquant for all columns with _intensity_sum in the column name and sum the column values
-    spot_assign = mcquant.filter(regex='_intensity_sum').sum(axis=1)
+    spot_assign = mcquant.filter(regex="_intensity_sum").sum(axis=1)
     spot_assign_total = int(sum(spot_assign))
-    spot_assign_per_cell = total_cells and spot_assign_total / total_cells or 0
-    #spot_assign_per_cell = spot_assign_total / total_cells
+    spot_assign_per_cell = total_cells and spot_assign_total / total_cells or 0
+    # spot_assign_per_cell = spot_assign_total / total_cells
     spot_assign_percent = spot_assign_total / spots_summary[1] * 100

-    return(total_cells,avg_area,spot_assign_per_cell,spot_assign_total,spot_assign_percent)
+    return (total_cells, avg_area, spot_assign_per_cell, spot_assign_total, spot_assign_percent)

+
 if __name__ == "__main__":
     # Write an argparse with input options mcquant_in, spots and output options outdir, sample_id
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-i",
-        "--mcquant",
-        help="mcquant regionprops_table."
-    )
-    parser.add_argument(
-        "-s",
-        "--spots",
-        help="Resolve biosciences spot table."
-    )
-    parser.add_argument(
-        "-o",
-        "--outdir",
-        help="Output directory."
-    )
+    parser.add_argument("-i", "--mcquant", help="mcquant regionprops_table.")
+    parser.add_argument("-s", "--spots", help="Resolve biosciences spot table.")
+    parser.add_argument("-o", "--outdir", help="Output directory.")

-    parser.add_argument(
-        "-d",
-        "--sample_id",
-        help="Sample ID."
-    )
+    parser.add_argument("-d", "--sample_id", help="Sample ID.")

-    parser.add_argument(
-        "-g",
-        "--segmentation_method",
-        help="Segmentation method used."
-    )
+    parser.add_argument("-g", "--segmentation_method", help="Segmentation method used.")

     args = parser.parse_args()

     ## Read in mcquant table
     mcquant = pd.read_csv(args.mcquant)

     ## Read in spot table
-    spots = pd.read_table(args.spots, sep='\t', names=['x', 'y', 'z', 'gene'])
+    spots = pd.read_table(args.spots, sep="\t", names=["x", "y", "z", "gene"])
     spots = spots[~spots.gene.str.contains("Duplicated")]

     ## Summarize spots table
     summary_spots = summarize_spots(spots)
-    summary_segmentation = summarize_segmasks(mcquant,summary_spots)
+    summary_segmentation = summarize_segmasks(mcquant, summary_spots)

     ## Create pandas data frame with one row per parameter and write each value in summary_segmentation to a new row in the data frame
-    summary_df = pd.DataFrame(columns=['sample_id','segmentation_method','total_cells','avg_area','total_spots','spot_assign_per_cell','spot_assign_total','spot_assign_percent',])
-    summary_df.loc[0] = [args.sample_id + "_" + args.segmentation_method,args.segmentation_method,summary_segmentation[0],summary_segmentation[1],summary_spots[1],summary_segmentation[2],summary_segmentation[3],summary_segmentation[4]]
+    summary_df = pd.DataFrame(
+        columns=[
+            "sample_id",
+            "segmentation_method",
+            "total_cells",
+            "avg_area",
+            "total_spots",
+            "spot_assign_per_cell",
+            "spot_assign_total",
+            "spot_assign_percent",
+        ]
+    )
+    summary_df.loc[0] = [
+        args.sample_id + "_" + args.segmentation_method,
+        args.segmentation_method,
+        summary_segmentation[0],
+        summary_segmentation[1],
+        summary_spots[1],
+        summary_segmentation[2],
+        summary_segmentation[3],
+        summary_segmentation[4],
+    ]

     # Write summary_df to a csv file
-    summary_df.to_csv(f"{args.outdir}/{args.sample_id}.{args.segmentation_method}.spot_QC.csv", header = True, index=False)
+    summary_df.to_csv(
+        f"{args.outdir}/{args.sample_id}.{args.segmentation_method}.spot_QC.csv", header=True, index=False
+    )
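
One line in collect_QC.py that black leaves alone is the guard `spot_assign_per_cell = total_cells and spot_assign_total / total_cells or 0`, a pre-ternary and/or idiom that short-circuits to 0 instead of raising ZeroDivisionError when the quantification table is empty. A small illustration with made-up numbers:

total_cells = 0
spot_assign_total = 42
# and/or idiom: evaluates to 0 when total_cells is 0
per_cell = total_cells and spot_assign_total / total_cells or 0
# equivalent, clearer conditional expression
per_cell = spot_assign_total / total_cells if total_cells else 0
print(per_cell)  # 0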
44 changes: 17 additions & 27 deletions bin/project_spots.dask.py
@@ -14,53 +14,45 @@
 from rich.progress import track
 from aicsimageio.writers import OmeTiffWriter

+
 # Make a function to project a table of spots with x,y coordinates onto a 2d plane based on reference image shape and add any duplicate spots to increase their pixel value in the output image
-def project_spots(spot_table,img):
+def project_spots(spot_table, img):
     # Initialize an empty image with the same shape as the reference image
-    img = np.zeros_like(img, dtype= 'int8')
+    img = np.zeros_like(img, dtype="int8")
     # Iterate through each spot in the table
     for spot in spot_table.itertuples():
         # Add the corresponding spot count to the pixel value at the spot's x,y coordinates
         img[spot.y, spot.x] += spot.counts
     return img

+
 if __name__ == "__main__":
     # Add a python argument parser with options for input, output and image size in x and y
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-i",
-        "--input",
-        help="Spot table to project."
-    )
-    parser.add_argument(
-        "-s",
-        "--sample_id",
-        help="Sample ID."
-    )
-    parser.add_argument(
-        "-d",
-        "--img_dims",
-        dest="img_dims",
-        help="Corresponding image to get dimensions from."
-    )
+    parser.add_argument("-i", "--input", help="Spot table to project.")
+    parser.add_argument("-s", "--sample_id", help="Sample ID.")
+    parser.add_argument("-d", "--img_dims", dest="img_dims", help="Corresponding image to get dimensions from.")

     args = parser.parse_args()

-    #spots = pd.read_csv(args.input)
-    spots = dd.read_table(args.input, sep='\t', names=['x', 'y', 'z', 'gene']).compute()
+    # spots = pd.read_csv(args.input)
+    spots = dd.read_table(args.input, sep="\t", names=["x", "y", "z", "gene"]).compute()
     img = tifffile.imread(args.img_dims)

-    spots = spots[["y","x", "gene"]]
+    spots = spots[["y", "x", "gene"]]

     ## Filter any genes marked with Duplicated
     spots = spots[~spots.gene.str.contains("Duplicated")]

     # Sum spots by z-axis
-    spots_zsum = spots.value_counts().to_frame('counts').reset_index()
+    spots_zsum = spots.value_counts().to_frame("counts").reset_index()

     # Project each gene into a 2d plane and add to list
     # Add a printed message that says "Projecting spots for gene X" for each gene in the list
-    spots_2d_list = [project_spots(spots_zsum[spots_zsum.gene == gene], img) for gene in track(spots_zsum.gene.unique(), description='[green]Projecting spots...')]
+    spots_2d_list = [
+        project_spots(spots_zsum[spots_zsum.gene == gene], img)
+        for gene in track(spots_zsum.gene.unique(), description="[green]Projecting spots...")
+    ]

     # Stack images on the c-axis
     spot_2d_stack = da.stack(spots_2d_list, axis=0)
Expand All @@ -69,7 +61,5 @@ def project_spots(spot_table,img):
channel_names = spots_zsum.gene.unique().tolist()
pd.DataFrame(channel_names).to_csv(args.sample_id + ".channel_names.csv", index=False, header=False)

#tifffile.imwrite(args.output, spot_2d_stack, metadata={'axes': 'CYX'})
OmeTiffWriter.save(spot_2d_stack,
args.sample_id + ".tiff",
dim_order = "CYX")
# tifffile.imwrite(args.output, spot_2d_stack, metadata={'axes': 'CYX'})
OmeTiffWriter.save(spot_2d_stack, args.sample_id + ".tiff", dim_order="CYX")
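
The projection logic in project_spots.dask.py reduces to accumulating per-gene spot counts into a 2D array indexed by each spot's coordinates. A tiny self-contained sketch with a hypothetical spot table whose column names match the script's; note that the script's int8 dtype caps any single pixel at 127 accumulated spots:

import numpy as np
import pandas as pd

spots_zsum = pd.DataFrame({"y": [0, 0, 2], "x": [1, 1, 2], "counts": [2, 3, 1]})
img = np.zeros((3, 3), dtype="int8")
for spot in spots_zsum.itertuples():
    img[spot.y, spot.x] += spot.counts  # duplicate (y, x) coordinates accumulate
print(img[0, 1])  # 5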
