Skip to content

Commit

Permalink
Merge pull request #1315 from dagardner-nv/david-fea-sherlock-nemo-se…
Browse files Browse the repository at this point in the history
…rvice

Docstrings & Tests for NeMoLLMService
  • Loading branch information
dagardner-nv authored Oct 27, 2023
2 parents 732b38a + 747eff9 commit 9d26514
Show file tree
Hide file tree
Showing 10 changed files with 362 additions and 24 deletions.
2 changes: 1 addition & 1 deletion ci/scripts/github/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

function print_env_vars() {
rapids-logger "Environ:"
env | grep -v -E "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|GH_TOKEN" | sort
env | grep -v -E "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|GH_TOKEN|NGC_API_KEY" | sort
}

rapids-logger "Env Setup"
Expand Down
2 changes: 1 addition & 1 deletion examples/llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def cli(ctx: click.Context, log_level: int, use_cpp: bool):

morpheus_logger = logging.getLogger("morpheus")

logger = logging.getLogger(__name__)
logger = logging.getLogger('.'.join(__name__.split('.')[:-1]))

# Set the parent logger for all of the llm examples to use morpheus so we can take advantage of configure_logging
logger.parent = morpheus_logger
Expand Down
8 changes: 5 additions & 3 deletions examples/llm/completion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus.utils.concat_df import concat_dataframes

logger = logging.getLogger(__name__)

Expand All @@ -52,15 +53,14 @@ def _build_engine():

engine.add_node("completion", inputs=["/prompts"], node=LLMGenerateNode(llm_client=llm_clinet))

engine.add_task_handler(inputs=["/extracter"], handler=SimpleTaskHandler())
engine.add_task_handler(inputs=["/completion"], handler=SimpleTaskHandler())

return engine


def pipeline(num_threads, pipeline_batch_size, model_max_batch_size, repeat_count: int):

config = Config()
config.mode = PipelineModes.OTHER

# Below properties are specified by the command line
config.num_threads = num_threads
Expand Down Expand Up @@ -107,6 +107,8 @@ def pipeline(num_threads, pipeline_batch_size, model_max_batch_size, repeat_coun

pipe.run()

logger.info("Pipeline complete. Received %s responses", len(sink.get_messages()))
messages = sink.get_messages()
responses = concat_dataframes(messages)
logger.info("Pipeline complete. Received %s responses\n%s", len(messages), responses['response'])

return start_time
119 changes: 101 additions & 18 deletions morpheus/llm/services/nemo_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,79 @@

logger = logging.getLogger(__name__)

IMPORT_ERROR_MESSAGE = (
"NemoLLM not found. Install it and other additional dependencies by running the following command:\n"
"`mamba env update -n ${CONDA_DEFAULT_ENV} --file docker/conda/environments/cuda11.8_examples.yml`")

try:
from nemollm.api import NemoLLM
except ImportError:
logger.error("NemoLLM not found. Please install NemoLLM to use this service.")
logger.error(IMPORT_ERROR_MESSAGE)


class NeMoLLMClient(LLMClient):
def _verify_nemo_llm():
    """
    Ensure the optional `NemoLLM` import succeeded before it is used.

    The module-level import of `NemoLLM` only logs an error when it fails, so this check turns a missing
    dependency into an `ImportError` carrying install instructions instead of a later attribute error.

    Raises
    ------
    ImportError
        If `NemoLLM` is not present in this module's globals.
    """
    nemo_llm_available = 'NemoLLM' in globals()
    if not nemo_llm_available:
        raise ImportError(IMPORT_ERROR_MESSAGE)

def __init__(self, parent: "NeMoLLMService", model_name: str, **model_kwargs) -> None:

class NeMoLLMClient(LLMClient):
"""
Client for interacting with a specific model in Nemo. This class should be constructed with the
`NeMoLLMService.get_client` method.
Parameters
----------
parent : NeMoLLMService
The parent service for this client.
model_name : str
The name of the model to interact with.
model_kwargs : dict[str, typing.Any]
Additional keyword arguments to pass to the model when generating text.
"""

def __init__(self, parent: "NeMoLLMService", model_name: str, **model_kwargs: typing.Any) -> None:
    """
    Store the parent service, model name and generation keyword arguments for later requests.

    Parameters
    ----------
    parent : NeMoLLMService
        The parent service for this client.
    model_name : str
        The name of the model to interact with.
    model_kwargs : typing.Any
        Additional keyword arguments to pass to the model when generating text.

    Raises
    ------
    ImportError
        If the `NemoLLM` package could not be imported.
    """
    # NOTE: a `**kwargs` annotation describes the type of each individual value, not the collected dict.
    super().__init__()

    # Fail early with a helpful ImportError rather than an attribute error at request time.
    _verify_nemo_llm()

    self._parent = parent
    self._model_name = model_name
    self._model_kwargs = model_kwargs

def generate(self, prompt: str) -> str:
    """
    Generate a single response for a single prompt.

    This delegates to `generate_batch` with a one-element list and returns the first element of the result.

    Parameters
    ----------
    prompt : str
        The prompt to generate a response for.

    Returns
    -------
    str
        The first entry of the list returned by `generate_batch`.
    """
    responses = self.generate_batch([prompt])
    return responses[0]

async def generate_async(self, prompt: str) -> str:
    """
    Asynchronously generate a single response for a single prompt.

    This awaits `generate_batch_async` with a one-element list and returns the first element of the result.

    Parameters
    ----------
    prompt : str
        The prompt to generate a response for.

    Returns
    -------
    str
        The first entry of the list produced by `generate_batch_async`.
    """
    responses = await self.generate_batch_async([prompt])
    return responses[0]

def generate_batch(self, prompts: list[str]) -> list[str]:

"""
Issue a request to generate a list of responses based on a list of prompts.
Parameters
----------
prompts : list[str]
The prompts to generate responses for.
"""
return typing.cast(
list[str],
self._parent._conn.generate_multiple(model=self._model_name,
Expand All @@ -53,7 +103,14 @@ def generate_batch(self, prompts: list[str]) -> list[str]:
**self._model_kwargs))

async def generate_batch_async(self, prompts: list[str]) -> list[str]:

"""
Issue an asynchronous request to generate a list of responses based on a list of prompts.
Parameters
----------
prompts : list[str]
The prompts to generate responses for.
"""
futures = [
asyncio.wrap_future(
self._parent._conn.generate(self._model_name, p, return_type="async", **self._model_kwargs))
Expand All @@ -62,38 +119,64 @@ async def generate_batch_async(self, prompts: list[str]) -> list[str]:

results = await asyncio.gather(*futures)

return [
typing.cast(str, NemoLLM.post_process_generate_response(r, return_text_completion_only=True))
for r in results
]
responses = []

for result in results:
result = NemoLLM.post_process_generate_response(result, return_text_completion_only=False)
if result.get('status', None) == 'fail':
raise RuntimeError(result.get('msg', 'Unknown error'))

responses.append(result['text'])

return responses


class NeMoLLMService(LLMService):
"""
A service for interacting with NeMo LLM models, this class should be used to create a client for a specific model.
Parameters
----------
api_key : str, optional
The API key for the LLM service, by default None. If `None` the API key will be read from the `NGC_API_KEY`
environment variable. If neither are present an error will be raised.
org_id : str, optional
The organization ID for the LLM service, by default None. If `None` the organization ID will be read from the
`NGC_ORG_ID` environment variable. This value is only required if the account associated with the `api_key` is
a member of multiple NGC organizations.
"""

def __init__(self, *, api_key: str = None, org_id: str = None) -> None:
super().__init__()
_verify_nemo_llm()

api_key = api_key if api_key is not None else os.environ.get("NGC_API_KEY", None)
org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None)

self._api_key = api_key
self._org_id = org_id

# Do checking on api key

# Class variables
self._conn: NemoLLM = NemoLLM(
# The client must configure the authentication and authorization parameters
# in accordance with the API server security policy.
# Configure Bearer authorization
api_key=self._api_key,
api_key=api_key,

# If you are in more than one LLM-enabled organization, you must
# specify your org ID in the form of a header. This is optional
# if you are only in one LLM-enabled org.
org_id=self._org_id,
org_id=org_id,
)

def get_client(self, model_name: str, **model_kwargs) -> NeMoLLMClient:
def get_client(self, model_name: str, **model_kwargs: typing.Any) -> NeMoLLMClient:
    """
    Returns a client for interacting with a specific model. This method is the preferred way to create a client.

    Parameters
    ----------
    model_name : str
        The name of the model to create a client for.
    model_kwargs : typing.Any
        Additional keyword arguments to pass to the model when generating text.

    Returns
    -------
    NeMoLLMClient
        A client constructed with this service as its parent.
    """
    # NOTE: a `**kwargs` annotation describes the type of each individual value, not the collected dict.
    return NeMoLLMClient(self, model_name, **model_kwargs)
2 changes: 1 addition & 1 deletion morpheus/stages/input/arxiv_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from langchain.schema import Document

IMPORT_ERROR_MESSAGE = (
"ArxivSource requires additional dependencies to be installed. Install them by runnign the following command: "
"ArxivSource requires additional dependencies to be installed. Install them by running the following command: "
"`mamba env update -n ${CONDA_DEFAULT_ENV} --file docker/conda/environments/cuda11.8_examples.yml`")


Expand Down
18 changes: 18 additions & 0 deletions tests/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,24 @@ def import_or_skip(modname: str,
raise


# pylint: disable=inconsistent-return-statements
def require_env_variable(varname: str, reason: str, fail_missing: bool = False) -> str:
    """
    Return the value of the environment variable `varname`.

    If the variable is not set and `fail_missing` is False the current test is skipped with `reason`;
    if `fail_missing` is True a `RuntimeError` carrying `reason` is raised instead.

    Parameters
    ----------
    varname : str
        Name of the environment variable to read.
    reason : str
        Message used for the skip or for the `RuntimeError`.
    fail_missing : bool, optional
        When True a missing variable is an error rather than a skip, by default False.

    Returns
    -------
    str
        The value of the environment variable.

    Raises
    ------
    RuntimeError
        If the variable is not set and `fail_missing` is True.
    """
    try:
        value = os.environ[varname]
    except KeyError as err:
        if not fail_missing:
            # `pytest.skip` raises, so execution never continues past this call.
            pytest.skip(reason=reason)

        raise RuntimeError(reason) from err

    return value


# pylint: enable=inconsistent-return-statements


def make_url(port: int, endpoint: str) -> str:
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
Expand Down
40 changes: 40 additions & 0 deletions tests/llm/services/nemo_llm/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from _utils import import_or_skip
from _utils import require_env_variable


@pytest.fixture(name="nemollm", autouse=True, scope='session')
def nemollm_fixture(fail_missing: bool):
    """
    Session-wide, automatically applied fixture yielding the result of importing `nemollm`.

    Every test in this sub-directory needs the `nemollm` package; `import_or_skip` is given the
    install instructions as the reason to report when the package is unavailable.
    """
    install_hint = ("Tests for the NeMoLLMService require the nemollm package to be installed, to install this run:\n"
                    "`mamba env update -n ${CONDA_DEFAULT_ENV} --file docker/conda/environments/cuda11.8_examples.yml`")
    nemollm_module = import_or_skip("nemollm", reason=install_hint, fail_missing=fail_missing)
    yield nemollm_module


@pytest.fixture(name="ngc_api_key", scope='session')
def ngc_api_key_fixture(fail_missing: bool):
    """
    Integration tests require an NGC API key.

    Yields the result of `require_env_variable` for the `NGC_API_KEY` environment variable. The `reason`
    string and the `fail_missing` flag are forwarded to that helper, which decides what happens when the
    variable is not set.
    """
    yield require_env_variable(
        varname="NGC_API_KEY",
        reason="nemo integration tests require the `NGC_API_KEY` environment variable to be defined.",
        fail_missing=fail_missing)
56 changes: 56 additions & 0 deletions tests/llm/services/nemo_llm/test_nemo_llm_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from morpheus.llm.services.nemo_llm_service import NeMoLLMClient


def _make_mock_nemo_llm():
    """
    Build a `MagicMock` standing in for a NemoLLM connection.

    `generate_multiple` returns `["test_output"]`, and calling the mock returns the mock itself.
    """
    nemo_llm = mock.MagicMock()
    nemo_llm.generate_multiple.return_value = ["test_output"]

    # Calling the mock hands back the very same object.
    nemo_llm.return_value = nemo_llm
    return nemo_llm


def _make_mock_nemo_service():
    """
    Build a `MagicMock` service whose `_conn` attribute is a mocked NemoLLM connection.

    Returns
    -------
    tuple
        `(mock_service, mock_nemo_llm)`, where `mock_service._conn is mock_nemo_llm` and calling
        `mock_service` returns `mock_service` itself.
    """
    nemo_llm = _make_mock_nemo_llm()

    service = mock.MagicMock()
    service._conn = nemo_llm
    service.return_value = service

    return (service, nemo_llm)


def test_generate():
    """
    `NeMoLLMClient.generate` returns the mocked output and calls `generate_multiple` once with the
    model name, the prompt wrapped in a list, `return_type="text"` and the extra keyword argument.
    """
    service, nemo_llm = _make_mock_nemo_service()
    client = NeMoLLMClient(service, "test_model", additional_arg="test_arg")

    response = client.generate("test_prompt")

    assert response == "test_output"
    nemo_llm.generate_multiple.assert_called_once_with(model="test_model",
                                                       prompts=["test_prompt"],
                                                       return_type="text",
                                                       additional_arg="test_arg")


def test_generate_batch():
    """
    `NeMoLLMClient.generate_batch` returns the mocked outputs and calls `generate_multiple` once with
    the model name, the full prompt list, `return_type="text"` and the extra keyword argument.
    """
    prompts = ["prompt1", "prompt2"]
    expected_outputs = ["output1", "output2"]

    service, nemo_llm = _make_mock_nemo_service()
    nemo_llm.generate_multiple.return_value = ["output1", "output2"]
    client = NeMoLLMClient(service, "test_model", additional_arg="test_arg")

    responses = client.generate_batch(["prompt1", "prompt2"])

    assert responses == expected_outputs
    nemo_llm.generate_multiple.assert_called_once_with(model="test_model",
                                                       prompts=prompts,
                                                       return_type="text",
                                                       additional_arg="test_arg")
Loading

0 comments on commit 9d26514

Please sign in to comment.