Skip to content

Commit

Permalink
Updated use case functions and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jenniferjiangkells committed Oct 14, 2024
1 parent c45a018 commit dad5336
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 129 deletions.
74 changes: 37 additions & 37 deletions healthchain/use_cases/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from healthchain.models import (
CDSRequest,
CDSResponse,
Card,
CDSService,
CDSServiceInformation,
)
Expand Down Expand Up @@ -162,14 +161,29 @@ def cds_discovery(self) -> CDSServiceInformation:

def cds_service(self, id: str, request: CDSRequest) -> CDSResponse:
"""
CDS service endpoint for FastAPI app, should be mounted to /cds-services/{id}
CDS service endpoint for FastAPI app, mounted to /cds-services/{id}
This method handles the execution of a specific CDS service. It validates the
service configuration, checks the input parameters, executes the service
function, and ensures the correct response type is returned.
Args:
id (str): The ID of the CDS service.
id (str): The unique identifier of the CDS service to be executed.
request (CDSRequest): The request object containing the input data for the CDS service.
Returns:
CDSResponse: The response object containing the cards generated by the CDS service.
Raises:
AssertionError: If the service function is not properly configured.
TypeError: If the input or output types do not match the expected types.
Note:
This method performs several checks to ensure the integrity of the service:
1. Verifies that the service API is configured.
2. Validates the signature of the service function.
3. Ensures the service function accepts a CDSRequest as its first argument.
4. Verifies that the service function returns a CDSResponse.
"""
# TODO: can register multiple services and fetch with id

Expand All @@ -178,42 +192,28 @@ def cds_service(self, id: str, request: CDSRequest) -> CDSResponse:
log.warning("CDS 'service_api' not configured, check class init.")
return CDSResponse(cards=[])

# Check service function signature
signature = inspect.signature(self._service_api.func)
assert (
len(signature.parameters) == 2
), f"Incorrect number of arguments: {len(signature.parameters)} {signature}; CDS Service functions currently only accept 'self' and a single input argument."

# Handle different input types
service_input = request
params = iter(inspect.signature(self._service_api.func).parameters.items())
for name, param in params:
if name != "self":
if param.annotation == str:
service_input = request.model_dump_json(exclude_none=True)
elif param.annotation == Dict:
service_input = request.model_dump(exclude_none=True)

# Call the service function
result = self._service_api.func(self, service_input)

# Check the result return type
if result is None:
# Check that the first argument of self._service_api.func is of type CDSRequest
func_signature = inspect.signature(self._service_api.func)
params = list(func_signature.parameters.values())
if len(params) < 2: # Only 'self' parameter
raise AssertionError(
"Service function must have at least one parameter besides 'self'"
)
first_param = params[1] # Skip 'self'
if first_param.annotation == inspect.Parameter.empty:
log.warning(
"CDS 'service_api' returned None, please check function definition."
"Service function parameter has no type annotation. Expected CDSRequest."
)
elif first_param.annotation != CDSRequest:
raise TypeError(
f"Expected first argument of service function to be CDSRequest, but got {first_param.annotation}"
)
return CDSResponse(cards=[])

if not isinstance(result, list):
if isinstance(result, Card):
result = [result]
else:
raise TypeError(f"Expected a list, but got {type(result).__name__}")
# Call the service function
response = self._service_api.func(self, request)

for card in result:
if not isinstance(card, Card):
raise TypeError(
f"Expected a list of 'Card' objects, but found an item of type {type(card).__name__}"
)
# Check that response is of type CDSResponse
if not isinstance(response, CDSResponse):
raise TypeError(f"Expected CDSResponse, but got {type(response).__name__}")

return CDSResponse(cards=result)
return response
38 changes: 32 additions & 6 deletions healthchain/use_cases/clindoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,28 @@ def endpoints(self) -> Dict[str, Endpoint]:

def process_notereader_document(self, request: CdaRequest) -> CdaResponse:
"""
Process the NoteReader document.
Process the NoteReader document using the configured service API.
This method handles the execution of the NoteReader service. It validates the
service configuration, checks the input parameters, executes the service
function, and ensures the correct response type is returned.
Args:
request (CdaRequest): The CdaRequest object containing the document.
request (CdaRequest): The request object containing the CDA document to be processed.
Returns:
CdaResponse: The CdaResponse object containing the processed document.
CdaResponse: The response object containing the processed CDA document.
Raises:
AssertionError: If the service function is not properly configured.
TypeError: If the output type does not match the expected CdaResponse type.
Note:
This method performs several checks to ensure the integrity of the service:
1. Verifies that the service API is configured.
2. Validates the signature of the service function.
3. Ensures the service function accepts a CdaRequest as its argument.
4. Verifies that the service function returns a CdaResponse.
"""
# Check service_api
if self._service_api is None:
Expand All @@ -155,9 +170,20 @@ def process_notereader_document(self, request: CdaRequest) -> CdaResponse:

# Check service function signature
signature = inspect.signature(self._service_api.func)
assert (
len(signature.parameters) == 2
), f"Incorrect number of arguments: {len(signature.parameters)} {signature}; service functions currently only accept 'self' and a single input argument."
params = list(signature.parameters.values())
if len(params) < 2: # Only 'self' parameter
raise AssertionError(
"Service function must have at least one parameter besides 'self'"
)
first_param = params[1] # Skip 'self'
if first_param.annotation == inspect.Parameter.empty:
log.warning(
"Service function parameter has no type annotation. Expected CdaRequest."
)
elif first_param.annotation != CdaRequest:
raise TypeError(
f"Expected first argument of service function to be CdaRequest, but got {first_param.annotation}"
)

# Call the service function
response = self._service_api.func(self, request)
Expand Down
38 changes: 38 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from healthchain.models.requests.cdarequest import CdaRequest
from healthchain.models.responses.cdaresponse import CdaResponse
from healthchain.models.responses.cdsresponse import CDSResponse, Card
from healthchain.service.soap.epiccdsservice import CDSServices
from healthchain.use_cases.cds import (
ClinicalDecisionSupport,
Expand All @@ -27,6 +28,8 @@
from healthchain.use_cases.clindoc import ClinicalDocumentation
from healthchain.workflows import UseCaseType

# TODO: Tidy up fixtures


@pytest.fixture(autouse=True)
def setup_caplog(caplog):
Expand Down Expand Up @@ -141,6 +144,41 @@ def test_cds_request():
return CDSRequest(**cds_dict)


@pytest.fixture
def test_cds_response_single_card():
return CDSResponse(
cards=[
Card(
summary="Test Card",
indicator="info",
source={"label": "Test Source"},
detail="This is a test card for CDS response",
)
]
)


@pytest.fixture
def test_cds_response_empty():
return CDSResponse(cards=[])


@pytest.fixture
def test_cds_response_multiple_cards():
return CDSResponse(
cards=[
Card(
summary="Test Card 1", indicator="info", source={"label": "Test Source"}
),
Card(
summary="Test Card 2",
indicator="warning",
source={"label": "Test Source"},
),
]
)


@pytest.fixture
def mock_client_decorator():
def mock_client_decorator(func):
Expand Down
8 changes: 1 addition & 7 deletions tests/pipeline/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ def sample_lookup():
}


# @pytest.fixture
# def mock_cda_annotator():
# with patch("healthchain.io.cdaconnector.CdaAnnotator") as mock:
# yield mock


@pytest.fixture
def mock_cda_connector():
with patch("healthchain.io.cdaconnector.CdaConnector") as mock:
Expand Down Expand Up @@ -120,7 +114,7 @@ def configure_pipeline(self, model_path: str) -> None:

@pytest.fixture
def mock_model():
with patch("healthchain.pipeline.components.models.Model") as mock:
with patch("healthchain.pipeline.components.model.Model") as mock:
model_instance = mock.return_value
model_instance.return_value = Document(
data="Processed note",
Expand Down
114 changes: 58 additions & 56 deletions tests/test_cds.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest

from unittest.mock import Mock
from healthchain.use_cases.cds import ClinicalDecisionSupport
from healthchain.models import Card
from healthchain.models.requests.cdsrequest import CDSRequest
from healthchain.models.responses.cdsresponse import CDSResponse


def test_initialization(cds):
Expand All @@ -14,8 +14,8 @@ def test_initialization(cds):
assert "service_mount" in cds.endpoints


def test_cds_discovery_client_not_set():
cds = ClinicalDecisionSupport()
def test_cds_discovery_client_not_set(cds):
cds._client = None
info = cds.cds_discovery()
assert info.services == []

Expand All @@ -27,73 +27,75 @@ def test_cds_discovery(cds):
assert cds_info.services[0].hook == "hook1"


def test_cds_service_no_api_set(test_cds_request):
cds = ClinicalDecisionSupport()
def test_cds_service_valid_response(
cds,
test_cds_request,
test_cds_response_single_card,
test_cds_response_multiple_cards,
):
# Test when everything is valid
def valid_service_func_single_card(self, request: CDSRequest):
return test_cds_response_single_card

cds._service_api = Mock(func=valid_service_func_single_card)

response = cds.cds_service("1", test_cds_request)
assert response.cards == []
assert response == test_cds_response_single_card

def valid_service_func_multiple_cards(self, request: CDSRequest):
return test_cds_response_multiple_cards

def test_cds_service(cds, test_cds_request):
# test returning list of cards
request = test_cds_request
cds._service_api.func.return_value = [
Card(
summary="example",
indicator="info",
source={"label": "test"},
)
]
response = cds.cds_service("1", request)
assert len(response.cards) == 1
assert response.cards[0].summary == "example"
assert response.cards[0].indicator == "info"
cds._service_api = Mock(func=valid_service_func_multiple_cards)

response = cds.cds_service("1", test_cds_request)
assert response == test_cds_response_multiple_cards

# test returning single card
cds._service_api.func.return_value = Card(
summary="example",
indicator="info",
source={"label": "test"},
)
response = cds.cds_service("1", request)
assert len(response.cards) == 1
assert response.cards[0].summary == "example"
assert response.cards[0].indicator == "info"

def test_cds_service_no_service_api(cds, test_cds_request):
# Test when _service_api is None
cds._service_api = None
response = cds.cds_service("test_id", test_cds_request)
assert isinstance(response, CDSResponse)
assert response.cards == []

def test_cds_service_incorrect_return_type(cds, test_cds_request):
request = test_cds_request
cds._service_api.func.return_value = "this is not a valid return type"
with pytest.raises(TypeError):
cds.cds_service("1", request)

def test_cds_service_invalid(cds, test_cds_request, test_cds_response_empty):
# Test when service_api function has invalid signature
def invalid_service_signature(self, invalid_param: str):
return test_cds_response_empty

def func_zero_params():
pass
cds._service_api = Mock(func=invalid_service_signature)

with pytest.raises(
TypeError, match="Expected first argument of service function to be CDSRequest"
):
cds.cds_service("test_id", test_cds_request)

def func_two_params(self, param1, param2):
pass
# Test when service_api function has invalid number of parameters
def invalid_service_num_params(self):
return test_cds_response_empty

cds._service_api = Mock(func=invalid_service_num_params)

def func_one_param(self, param):
pass
with pytest.raises(
AssertionError,
match="Service function must have at least one parameter besides 'self'",
):
cds.cds_service("test_id", test_cds_request)

# Test when service_api function returns invalid type
def invalid_service_return_type(self, request: CDSRequest):
return "Not a CDSResponse"

def test_cds_service_correct_number_of_parameters(cds, test_cds_request):
# Function with one parameter apart from 'self'
cds._service_api = Mock(func=func_one_param)
cds._service_api = Mock(func=invalid_service_return_type)

# Should not raise an assertion error
cds.cds_service("1", test_cds_request)
with pytest.raises(TypeError, match="Expected CDSResponse, but got str"):
cds.cds_service("test_id", test_cds_request)

# test no annotation - should not raise error
def valid_service_func_no_annotation(self, request):
return test_cds_response_empty

def test_cds_service_incorrect_number_of_parameters(cds, test_cds_request):
# Test with zero parameters apart from 'self'
cds._service_api = Mock(func=func_zero_params)
with pytest.raises(AssertionError):
cds.cds_service("1", test_cds_request)
cds._service_api = Mock(func=valid_service_func_no_annotation)

# Test with more than one parameter apart from 'self'
cds._service_api = Mock(func=func_two_params)
with pytest.raises(AssertionError):
cds.cds_service("1", test_cds_request)
assert cds.cds_service("test_id", test_cds_request) == test_cds_response_empty
Loading

0 comments on commit dad5336

Please sign in to comment.