diff --git a/src/napari_ndev/_plate_mapper.py b/src/napari_ndev/_plate_mapper.py index c71f867..a244411 100644 --- a/src/napari_ndev/_plate_mapper.py +++ b/src/napari_ndev/_plate_mapper.py @@ -12,12 +12,22 @@ class PlateMapper: ---------- plate_size : int The size of the plate (e.g., 96, 384). + Defaults to 96. + leading_zeroes : bool + Whether to include leading zeroes in the column labels. + Defaults to False. + treatments : dict + A dictionary mapping treatments to conditions and well ranges. wells : dict A dictionary mapping plate sizes to the number of rows and columns. plate_map : pandas.DataFrame The plate map DataFrame with well labels. - plate_map_pivot : pandas.DataFrame + pivoted_plate_map : pandas.DataFrame The wide-formatted plate map DataFrame with treatments as columns. + Pivots only one treatment at a time. + styled_plate_map : pandas.io.formats.style.Styler + The styled pivoted plate map DataFrame with different background + colors for each unique value. Methods ------- @@ -34,7 +44,7 @@ class PlateMapper: """ - def __init__(self, plate_size=96): + def __init__(self, plate_size=96, treatments=None, leading_zeroes=False, ): """ Initialize a PlateMapper object. @@ -42,9 +52,17 @@ def __init__(self, plate_size=96): ---------- plate_size : int, optional The size of the plate. Defaults to 96. + leading_zeroes : bool, optional + Whether to include leading zeroes in the column labels. + Defaults to False. + treatments : dict, optional + A dictionary mapping treatments to conditions and well ranges. + If provided, the treatments will be assigned to the plate map. + Defaults to None. """ self.plate_size = plate_size + self.leading_zeroes = leading_zeroes self.wells = { 6: (2, 3), 12: (3, 4), @@ -54,7 +72,13 @@ def __init__(self, plate_size=96): 384: (16, 24), } self.plate_map = self.create_empty_plate_map() - self.plate_map_pivot = None + self.pivoted_plate_map = None + self.styled_plate_map = None + + if treatments: + self.assign_treatments(treatments) + # pivot the first key in treatments + self.get_styled_plate_map(next(iter(treatments.keys()))) def create_empty_plate_map(self): """ @@ -69,7 +93,10 @@ def create_empty_plate_map(self): num_rows, num_columns = self.wells[self.plate_size] row_labels = list(string.ascii_uppercase)[:num_rows] - column_labels = list(range(1, num_columns + 1)) + if self.leading_zeroes: + column_labels = [f'{i:02d}' for i in range(1, num_columns + 1)] + else: + column_labels = list(range(1, num_columns + 1)) well_rows = row_labels * num_columns well_rows.sort() # needed to sort the rows correctly @@ -110,13 +137,14 @@ def assign_treatments(self, treatments): well_condition = ( (self.plate_map['row'] >= start_row) & (self.plate_map['row'] <= end_row) - & (self.plate_map['column'] >= start_col) - & (self.plate_map['column'] <= end_col) + & (self.plate_map['column'].astype(int) >= start_col) + & (self.plate_map['column'].astype(int) <= end_col) ) else: row, col = well[0], int(well[1:]) - well_condition = (self.plate_map['row'] == row) & ( - self.plate_map['column'] == col + well_condition = ( + (self.plate_map['row'] == row) + & (self.plate_map['column'] == col) ) self.plate_map.loc[well_condition, treatment] = condition @@ -140,7 +168,7 @@ def get_pivoted_plate_map(self, treatment): plate_map_pivot = self.plate_map.pivot( index='row', columns='column', values=treatment ) - self.plate_map_pivot = plate_map_pivot + self.pivoted_plate_map = plate_map_pivot return plate_map_pivot def get_styled_plate_map(self, treatment, palette='colorblind'): @@ -162,9 +190,9 @@ def get_styled_plate_map(self, treatment, palette='colorblind'): """ from seaborn import color_palette - self.plate_map_pivot = self.get_pivoted_plate_map(treatment) + self.pivoted_plate_map = self.get_pivoted_plate_map(treatment) - unique_values = pd.unique(self.plate_map_pivot.values.flatten()) + unique_values = pd.unique(self.pivoted_plate_map.values.flatten()) unique_values = unique_values[pd.notna(unique_values)] color_palette_hex = color_palette(palette).as_hex() @@ -173,15 +201,16 @@ def get_styled_plate_map(self, treatment, palette='colorblind'): # Use next() to get the next color color_dict = {value: next(palette_cycle) for value in unique_values} - def get_background_color(value): + def get_background_color(value): # pragma: no cover if pd.isna(value): return '' return f'background-color: {color_dict[value]}' plate_map_styled = ( - self.plate_map_pivot.style.applymap(get_background_color) + self.pivoted_plate_map.style.applymap(get_background_color) .set_caption(f'{treatment} Plate Map') .format(lambda x: '' if pd.isna(x) else x) ) + self.styled_plate_map = plate_map_styled return plate_map_styled diff --git a/src/napari_ndev/_tests/test_plate_mapper.py b/src/napari_ndev/_tests/test_plate_mapper.py index 46f26b4..ef50c52 100644 --- a/src/napari_ndev/_tests/test_plate_mapper.py +++ b/src/napari_ndev/_tests/test_plate_mapper.py @@ -11,6 +11,40 @@ def plate_mapper(): return PlateMapper(96) + + + +def test_plate_mapper_init_empty(): + pm = PlateMapper() + plate_map = pm.plate_map + assert isinstance(plate_map, pd.DataFrame) + assert pm.pivoted_plate_map is None + assert pm.styled_plate_map is None + assert len(plate_map) == 96 + assert len(plate_map.columns) == 3 + assert 'row' in plate_map.columns + assert 'A' in plate_map['row'].values + assert 'column' in plate_map.columns + assert 1 in plate_map['column'].values + assert 'well_id' in plate_map.columns + assert 'A1' in plate_map['well_id'].values + +def test_plate_mapper_init_with_plate_size(): + pm = PlateMapper(384) + plate_map = pm.plate_map + assert len(plate_map) == 384 + assert len(plate_map.columns) == 3 + assert 'P' in plate_map['row'].values # 16th letter + assert 24 in plate_map['column'].values + +def test_plate_mapper_leading_zeroes(): + pm = PlateMapper(leading_zeroes=True) + assert 'A' in pm.plate_map['row'].values + assert '01' in pm.plate_map['column'].values + assert '12' in pm.plate_map['column'].values + assert 'A01' in pm.plate_map['well_id'].values + assert 'H12' in pm.plate_map['well_id'].values + @pytest.fixture def treatments(): return { @@ -18,24 +52,13 @@ def treatments(): 'Treatment2': {'Condition3': ['D4:E5']}, } - -def test_plate_mapper_create_empty_plate_map(plate_mapper: PlateMapper): - plate_map_df = plate_mapper.create_empty_plate_map() - - assert isinstance(plate_map_df, pd.DataFrame) - assert len(plate_map_df) == 96 - assert len(plate_map_df.columns) == 3 - assert 'row' in plate_map_df.columns - assert 'column' in plate_map_df.columns - assert 'well_id' in plate_map_df.columns - - -def test_plate_mapper_assign_treatments( - plate_mapper: PlateMapper, treatments: dict[str, dict[str, list[str]]] -): - plate_map = plate_mapper.assign_treatments(treatments) +def test_plate_mapper_init_with_treatments(treatments): + pm = PlateMapper(96, treatments=treatments) + plate_map = pm.plate_map + pivoted_pm = pm.pivoted_plate_map assert isinstance(plate_map, pd.DataFrame) + assert isinstance(pivoted_pm, pd.DataFrame) assert 'Treatment1' in plate_map.columns assert 'Treatment2' in plate_map.columns assert ( @@ -63,6 +86,14 @@ def test_plate_mapper_assign_treatments( == 'Condition3' ) +def test_plate_mapper_init_with_treatments_and_leading_zeroes(treatments): + pm = PlateMapper(96, treatments=treatments, leading_zeroes=True) + plate_map = pm.plate_map + + assert ( + plate_map.loc[plate_map['well_id'] == 'A01', 'Treatment1'].values[0] + ) + def test_plate_mapper_get_pivoted_plate_map( plate_mapper: PlateMapper, treatments: dict[str, dict[str, list[str]]]