Skip to content

Commit

Permalink
Fixed #111
Browse files Browse the repository at this point in the history
  • Loading branch information
khoroshevskyi committed Jan 6, 2024
1 parent 0441886 commit aa8b81e
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 10 deletions.
171 changes: 161 additions & 10 deletions pepdbagent/modules/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import peppy
from peppy.const import SAMPLE_TABLE_INDEX_KEY
from sqlalchemy import select, and_
from sqlalchemy import select, and_, func
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified

Expand Down Expand Up @@ -105,7 +105,13 @@ def get(
)

def update(
self, namespace: str, name: str, tag: str, sample_name: str, update_dict: dict
self,
namespace: str,
name: str,
tag: str,
sample_name: str,
update_dict: dict,
full_update: bool = False,
) -> None:
"""
Update one sample in the database
Expand All @@ -117,6 +123,7 @@ def update(
:param update_dict: dictionary with sample data (key: value pairs). e.g.
{"sample_name": "sample1",
"sample_protocol": "sample1 protocol"}
:param full_update: if True, update all sample fields, if False, update only fields from update_dict
:return: None
"""
statement = select(Samples).where(
Expand Down Expand Up @@ -146,10 +153,18 @@ def update(
project_mapping = session.scalar(project_statement)

if sample_mapping:
sample_mapping.sample.update(update_dict)
sample_mapping.sample_name = sample_mapping.sample.get(
project_mapping.config.get(SAMPLE_TABLE_INDEX_KEY, "sample_name")
)
if full_update:
sample_mapping.sample = update_dict
else:
sample_mapping.sample.update(update_dict)
try:
sample_mapping.sample_name = sample_mapping.sample[
project_mapping.config.get(SAMPLE_TABLE_INDEX_KEY, "sample_name")
]
except KeyError:
raise KeyError(
f"Sample index key {project_mapping.config.get(SAMPLE_TABLE_INDEX_KEY, 'sample_name')} not found in sample dict"
)

# This line needed due to: https://github.com/sqlalchemy/sqlalchemy/issues/5218
flag_modified(sample_mapping, "sample")
Expand All @@ -162,8 +177,144 @@ def update(
f"Sample {namespace}/{name}:{tag}?{sample_name} not found in the database"
)

def add(
self,
namespace: str,
name: str,
tag: str,
sample_dict: dict,
overwrite: bool = False,
) -> None:
"""
Add one sample to the project in the database
:param namespace: namespace of the project
:param name: name of the project
:param tag: tag (or version) of the project.
:param overwrite: overwrite sample if it already exists
:param sample_dict: dictionary with sample data (key: value pairs). e.g.
{"sample_name": "sample1",
"sample_protocol": "sample1 protocol"}
:return: None
"""

with Session(self._sa_engine) as session:
project_statement = select(Projects).where(
and_(
Projects.namespace == namespace,
Projects.name == name,
Projects.tag == tag,
)
)
# project mapping is needed to update number of samples, last_update_date and get sample_index_key
project_mapping = session.scalar(project_statement)
try:
sample_name = sample_dict[
project_mapping.config.get(SAMPLE_TABLE_INDEX_KEY, "sample_name")
]
except KeyError:
raise KeyError(
f"Sample index key {project_mapping.config.get(SAMPLE_TABLE_INDEX_KEY, 'sample_name')} not found in sample dict"
)
project_where_statement = (
Samples.project_id
== select(Projects.id)
.where(
and_(
Projects.namespace == namespace,
Projects.name == name,
Projects.tag == tag,
),
)
.scalar_subquery()
)
statement = select(Samples).where(
and_(project_where_statement, Samples.sample_name == sample_name)
)

sample_mapping = session.scalar(statement)
row_number = (
session.execute(
select(func.max(Samples.row_number)).where(project_where_statement)
).one()[0]
or 0
)

if sample_mapping and not overwrite:
raise ValueError(
f"Sample {namespace}/{name}:{tag}?{sample_name} already exists in the database"
)
elif sample_mapping and overwrite:
self.update(
namespace=namespace,
name=name,
tag=tag,
sample_name=sample_name,
update_dict=sample_dict,
full_update=True,
)
return None
else:
sample_mapping = Samples(
sample=sample_dict,
row_number=row_number + 1,
project_id=project_mapping.id,
sample_name=sample_name,
)
project_mapping.number_of_samples += 1
project_mapping.last_update_date = datetime.datetime.now(datetime.timezone.utc)

session.add(sample_mapping)
session.commit()

def delete(
self,
namespace: str,
name: str,
tag: str,
sample_name: str,
) -> None:
"""
Delete one sample from the database
# TODO: add "add sample" method
# TODO: add "delete sample" method
# TODO: check if samples are in correct order if they were deleted or added
# TODO: ensure that this methods update project timestamp
:param namespace: namespace of the project
:param name: name of the project
:param tag: tag (or version) of the project.
:param sample_name: sample_name of the sample
:return: None
"""
statement = select(Samples).where(
and_(
Samples.project_id
== select(Projects.id)
.where(
and_(
Projects.namespace == namespace,
Projects.name == name,
Projects.tag == tag,
),
)
.scalar_subquery(),
Samples.sample_name == sample_name,
)
)
project_statement = select(Projects).where(
and_(
Projects.namespace == namespace,
Projects.name == name,
Projects.tag == tag,
)
)
with Session(self._sa_engine) as session:
sample_mapping = session.scalar(statement)
project_mapping = session.scalar(project_statement)

if sample_mapping:
session.delete(sample_mapping)
project_mapping.number_of_samples -= 1
project_mapping.last_update_date = datetime.datetime.now(datetime.timezone.utc)
session.commit()
else:
raise SampleNotFoundError(
f"Sample {namespace}/{name}:{tag}?{sample_name} not found in the database"
)
83 changes: 83 additions & 0 deletions tests/test_pepagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ProjectNotFoundError,
ProjectNotInFavorites,
ProjectAlreadyInFavorites,
SampleNotFoundError,
)
from .conftest import DNS

Expand Down Expand Up @@ -815,3 +816,85 @@ def test_project_timestamp_was_changed(self, initiate_pepdb_con, namespace, name
annotation2 = initiate_pepdb_con.annotation.get(namespace, name, "default")

assert annotation1.results[0].last_update_date != annotation2.results[0].last_update_date

@pytest.mark.parametrize(
"namespace, name, sample_name",
[
["namespace1", "amendments1", "pig_0h"],
],
)
def test_delete_sample(self, initiate_pepdb_con, namespace, name, sample_name):
one_sample = initiate_pepdb_con.sample.get(namespace, name, sample_name)
assert isinstance(one_sample, peppy.Sample)

initiate_pepdb_con.sample.delete(namespace, name, tag="default", sample_name=sample_name)

with pytest.raises(SampleNotFoundError):
initiate_pepdb_con.sample.get(namespace, name, tag="default", sample_name=sample_name)

@pytest.mark.parametrize(
"namespace, name, tag, sample_dict",
[
[
"namespace1",
"amendments1",
"default",
{
"sample_name": "new_sample",
"time": "new_time",
},
],
],
)
def test_add_sample(self, initiate_pepdb_con, namespace, name, tag, sample_dict):
prj = initiate_pepdb_con.project.get(namespace, name)
initiate_pepdb_con.sample.add(namespace, name, tag, sample_dict)

prj2 = initiate_pepdb_con.project.get(namespace, name)

assert len(prj.samples) + 1 == len(prj2.samples)
assert prj2.samples[-1].sample_name == sample_dict["sample_name"]

@pytest.mark.parametrize(
"namespace, name, tag, sample_dict",
[
[
"namespace1",
"amendments1",
"default",
{
"sample_name": "pig_0h",
"time": "new_time",
},
],
],
)
def test_overwrite_sample(self, initiate_pepdb_con, namespace, name, tag, sample_dict):
assert initiate_pepdb_con.project.get(namespace, name).get_sample("pig_0h").time == "0"
initiate_pepdb_con.sample.add(namespace, name, tag, sample_dict, overwrite=True)

assert (
initiate_pepdb_con.project.get(namespace, name).get_sample("pig_0h").time == "new_time"
)

@pytest.mark.parametrize(
"namespace, name, tag, sample_dict",
[
[
"namespace1",
"amendments1",
"default",
{
"sample_name": "new_sample",
"time": "new_time",
},
],
],
)
def test_delete_and_add(self, initiate_pepdb_con, namespace, name, tag, sample_dict):
prj = initiate_pepdb_con.project.get(namespace, name)
sample_dict = initiate_pepdb_con.sample.get(namespace, name, "pig_0h", raw=True)
initiate_pepdb_con.sample.delete(namespace, name, tag, "pig_0h")
initiate_pepdb_con.sample.add(namespace, name, tag, sample_dict)
prj2 = initiate_pepdb_con.project.get(namespace, name)
assert prj.get_sample("pig_0h").to_dict() == prj2.get_sample("pig_0h").to_dict()

0 comments on commit aa8b81e

Please sign in to comment.