diff --git a/healthchain/pipeline/__init__.py b/healthchain/pipeline/__init__.py index e811479..c368f97 100644 --- a/healthchain/pipeline/__init__.py +++ b/healthchain/pipeline/__init__.py @@ -1,6 +1,6 @@ from healthchain.pipeline.basepipeline import Pipeline from healthchain.pipeline.components.basecomponent import BaseComponent, Component -from healthchain.pipeline.components.models import Model +from healthchain.pipeline.components.model import Model from healthchain.pipeline.components.preprocessors import TextPreProcessor from healthchain.pipeline.components.postprocessors import TextPostProcessor from healthchain.pipeline.genericpipeline import GenericPipeline diff --git a/healthchain/pipeline/components/__init__.py b/healthchain/pipeline/components/__init__.py index 47950ab..48b270e 100644 --- a/healthchain/pipeline/components/__init__.py +++ b/healthchain/pipeline/components/__init__.py @@ -1,6 +1,6 @@ from .preprocessors import TextPreProcessor from .postprocessors import TextPostProcessor -from .models import Model +from .model import Model from .basecomponent import BaseComponent, Component __all__ = [ diff --git a/healthchain/pipeline/components/llm.py b/healthchain/pipeline/components/llm.py new file mode 100644 index 0000000..e0312ca --- /dev/null +++ b/healthchain/pipeline/components/llm.py @@ -0,0 +1,20 @@ +from healthchain.pipeline.components.basecomponent import Component +from healthchain.io.containers import Document +from typing import TypeVar, Generic + +T = TypeVar("T") + + +# TODO: implement this class +class LLM(Component[T], Generic[T]): + def __init__(self, model_name: str): + self.model = model_name + + def load_model(self): + pass + + def load_chain(self): + pass + + def __call__(self, doc: Document) -> Document: + return doc diff --git a/healthchain/pipeline/components/models.py b/healthchain/pipeline/components/model.py similarity index 100% rename from healthchain/pipeline/components/models.py rename to healthchain/pipeline/components/model.py diff --git a/healthchain/pipeline/medicalcodingpipeline.py b/healthchain/pipeline/medicalcodingpipeline.py index 1e495f9..1754c36 100644 --- a/healthchain/pipeline/medicalcodingpipeline.py +++ b/healthchain/pipeline/medicalcodingpipeline.py @@ -2,7 +2,7 @@ from healthchain.pipeline.basepipeline import Pipeline from healthchain.pipeline.components.preprocessors import TextPreProcessor from healthchain.pipeline.components.postprocessors import TextPostProcessor -from healthchain.pipeline.components.models import Model +from healthchain.pipeline.components.model import Model # TODO: Implement this pipeline in full diff --git a/healthchain/pipeline/summarizationpipeline.py b/healthchain/pipeline/summarizationpipeline.py new file mode 100644 index 0000000..62b2c2b --- /dev/null +++ b/healthchain/pipeline/summarizationpipeline.py @@ -0,0 +1,19 @@ +from healthchain.io.cdsfhirconnector import CdsFhirConnector +from healthchain.pipeline.basepipeline import Pipeline +from healthchain.pipeline.components.llm import LLM + + +# TODO: Implement this pipeline in full +class SummarizationPipeline(Pipeline): + def configure_pipeline(self, model_name: str) -> None: + cds_fhir_connector = CdsFhirConnector(hook_name="encounter-discharge") + self.add_input(cds_fhir_connector) + + # Add summarization component + llm = LLM(model_name) + self.add(llm, stage="summarization") + + # Maybe you can have components that create cards + # self.add(CardCreator(), stage="card-creation") + + self.add_output(cds_fhir_connector)