[BUG]: Exceptions raised in an LLMNode don't always halt the pipeline #2086

Open
2 tasks done
dagardner-nv opened this issue Dec 17, 2024 · 0 comments
Labels
bug Something isn't working

Version

25.02, 24.10

Which installation method(s) does this occur on?

Source

Describe the bug.

Originally observed in vulnerability-analysis with Morpheus 24.10, and reproduced in Morpheus 25.02.

The bug occurs when:

  1. The pipeline uses either the HttpServerSourceStage or the PydanticHttpStage. The key point is that the source has no natural stop condition (e.g., end of file).
  2. The pipeline includes an LLMEngine, and one of its nodes raises an exception.
  3. Inputs arrive at the source stage infrequently.
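To make the conditions above concrete, here is a toy model in plain asyncio. It is not Morpheus code, and every name in it is hypothetical. It only reproduces the shape of the symptom: a driver loop hands each input to a worker task and checks whether the worker has failed only when the source produces another item or runs out. It makes no claim about the actual root cause, which this report does not identify.

```python
import asyncio


async def run_toy_pipeline(inputs: list[str], node_fails: bool,
                           source_is_finite: bool) -> list[str]:
    """Toy model (not Morpheus code). Returns a log of what happened."""
    events: list[str] = []
    queue: asyncio.Queue = asyncio.Queue()

    async def worker() -> None:
        # Stands in for the LLM engine: processes one item at a time.
        while True:
            item = await queue.get()
            events.append(f"worker got {item}")
            if node_fails:
                raise ValueError("node error")

    worker_task = asyncio.create_task(worker())

    def check_worker() -> bool:
        # The driver only looks at the worker's state when this is called.
        if worker_task.done():
            events.append(f"driver noticed {worker_task.exception()!r}")
            return True
        return False

    try:
        for item in inputs:  # each item stands in for one HTTP POST
            if check_worker():
                return events  # the failure surfaces only now, on the next input
            await queue.put(item)
            await asyncio.sleep(0)  # let the worker process the item
        if source_is_finite:
            check_worker()  # end of input is a natural point to notice the failure
        else:
            events.append("driver waiting for next input")  # failure stays unreported
        return events
    finally:
        # Clean up; awaiting also retrieves the exception so asyncio doesn't warn.
        worker_task.cancel()  # no-op if the worker already failed
        try:
            await worker_task
        except (asyncio.CancelledError, ValueError):
            pass
```

With a finite source the failure is noticed at end of input. With an indefinite source it stays unreported until a second input arrives, which mirrors the two exit triggers described in this report.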

In this situation:

  1. A client sends an HTTP POST request.
  2. The pipeline receives the message and begins processing it.
  3. A node in the LLMEngine raises an exception.
  4. Nothing is printed to the log, and the pipeline just waits.

At this point one of two things will cause the pipeline to exit:

  • The user presses Ctrl-C.
  • The client sends another HTTP POST request.
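For reference, a client that sends these requests might look like the sketch below. It assumes the HttpServerSourceStage runs with its default port and endpoint (assumed here to be 8080 and /message) and that the body is a JSON array of records with a country column, which is the column the repro's engine reads. The helper names are hypothetical. The --use_pydantic variant would instead expect a body matching ScanInfoInput (a countries list), at an endpoint not shown in this report.

```python
import json
import urllib.request

# Assumed HttpServerSourceStage defaults; adjust if your version differs.
DEFAULT_URL = "http://localhost:8080/message"


def build_request(countries: list[str], url: str = DEFAULT_URL) -> urllib.request.Request:
    # One JSON record per row; the repro's LLM engine reads the "country" column.
    body = json.dumps([{"country": c} for c in countries]).encode("utf-8")
    return urllib.request.Request(url, data=body, method="POST",
                                  headers={"Content-Type": "application/json"})


def send(countries: list[str], url: str = DEFAULT_URL, timeout: float = 10.0) -> int:
    # Sends one POST and returns the HTTP status code from the server.
    with urllib.request.urlopen(build_request(countries, url), timeout=timeout) as resp:
        return resp.status
```

Calling send(["France"]) once reproduces the hang; calling it a second time is the second exit trigger listed above.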

Minimum reproducible example

# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import typing

import click
from pydantic import BaseModel

from morpheus.cli.utils import get_log_levels
from morpheus.config import Config
from morpheus.config import ConfigAutoEncoder
from morpheus.config import CppConfig
from morpheus.config import PipelineModes
from morpheus.messages import ControlMessage
from morpheus.messages import MessageMeta
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.pipeline.stage_decorator import stage
from morpheus.stages.input.appshield_source_stage import AppShieldSourceStage
from morpheus.stages.input.azure_source_stage import AzureSourceStage
from morpheus.stages.input.http_client_source_stage import HttpClientSourceStage
from morpheus.stages.input.http_server_source_stage import HttpServerSourceStage
from morpheus.stages.input.kafka_source_stage import KafkaSourceStage
from morpheus.stages.input.rss_source_stage import RSSSourceStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus.utils.logger import configure_logging
from morpheus.cli.utils import parse_log_level
from morpheus.utils.type_utils import get_df_class
from morpheus_llm.llm import LLMEngine
from morpheus_llm.llm import LLMContext
from morpheus_llm.llm import LLMNodeBase
from morpheus_llm.llm.nodes.extracter_node import ExtracterNode
from morpheus_llm.llm.nodes.prompt_template_node import PromptTemplateNode
from morpheus_llm.llm.task_handlers.simple_task_handler import SimpleTaskHandler
from morpheus_llm.stages.llm.llm_engine_stage import LLMEngineStage
from pydantic_http_stage import PydanticHttpStage


class ScanInfoInput(BaseModel):
    countries: list[str]


class LLMPrintNode(LLMNodeBase):

    def __init__(self, raise_error: bool = False):
        super().__init__()
        self._raise_error = raise_error

    def get_input_names(self):
        return ["prompts"]

    async def execute(self, context: LLMContext):

        # Get this node's inputs (the rendered prompts from the upstream node)
        input_dict = context.get_inputs()

        print(f"LLMPrintNode Received:\n{input_dict}")
        if self._raise_error:
            print("LLMPrintNode Raising error")
            raise ValueError("LLMPrintNode Error")

        context.set_output(input_dict[self.get_input_names()[0]])

        return context


def _build_engine(raise_error: bool = False) -> LLMEngine:
    engine = LLMEngine()

    engine.add_node("extracter", node=ExtracterNode())

    engine.add_node("prompts",
                    inputs=["/extracter"],
                    node=PromptTemplateNode(template="What is the capital of {{country}}?", template_format="jinja"))

    engine.add_node("printer", inputs=["/prompts"], node=LLMPrintNode(raise_error=raise_error))

    engine.add_task_handler(inputs=["/printer"], handler=SimpleTaskHandler())

    return engine


@click.command()
@click.option("--use_python", is_flag=True, default=False, show_default=True)
@click.option("--use_pydantic", is_flag=True, default=False, show_default=True)
@click.option("--raise_error", is_flag=True, default=False, show_default=True)
@click.option("--src_type", type=click.Choice(['server', 'client']), default='server', show_default=True)
@click.option("--use_kafka", is_flag=True, default=False, show_default=True)
@click.option("--use_rss", is_flag=True, default=False, show_default=True)
@click.option("--use_appshield", is_flag=True, default=False, show_default=True)
@click.option("--use_aes", is_flag=True, default=False, show_default=True)
@click.option("--bootstrap_servers", type=str, default=os.environ.get('BOOTSTRAP_SERVER', "auto"), show_default=True)
@click.option("--sleep_time", type=float, default=5.0, show_default=True)
@click.option("--log_level",
              default="DEBUG",
              type=click.Choice(get_log_levels(), case_sensitive=False),
              callback=parse_log_level,
              show_default=True,
              help="Specify the logging level to use.")
def run_pipeline(log_level: int,
                 use_python: bool,
                 use_pydantic: bool,
                 raise_error: bool,
                 sleep_time: float,
                 use_kafka: bool,
                 use_rss: bool,
                 use_appshield: bool,
                 use_aes: bool,
                 src_type: str,
                 bootstrap_servers: str):
    # Enable the default logger
    configure_logging(log_level=log_level)

    CppConfig.set_should_use_cpp(not use_python)

    config = Config()
    config.mode = PipelineModes.OTHER
    pipeline = LinearPipeline(config)

    if use_kafka:
        pipeline.set_source(KafkaSourceStage(config, bootstrap_servers=bootstrap_servers, input_topic=["test_pcap"]))
    elif use_rss:
        pipeline.set_source(
            RSSSourceStage(config,
                           feed_input=["https://www.nasa.gov/rss/dyn/breaking_news.rss"],
                           run_indefinitely=True,
                           interval_secs=sleep_time))
    elif use_appshield:
        pipeline.set_source(
            AppShieldSourceStage(config,
                                 input_glob="/tmp/empty_dir/*",
                                 plugins_include=['ldrmodules', 'threadlist', 'envars', 'vadinfo', 'handles'],
                                 cols_include=["SHA256"],
                                 watch_directory=True))
    elif use_aes:
        config.ae = ConfigAutoEncoder()
        config.ae.feature_columns = ["SHA256"]
        pipeline.set_source(AzureSourceStage(config, input_glob="/tmp/empty_dir/*", watch_directory=True))
    else:
        if src_type == "server":
            if use_pydantic:
                pipeline.set_source(
                    PydanticHttpStage(config, bind_address="0.0.0.0", sleep_time=sleep_time,
                                      input_schema=ScanInfoInput))
            else:
                pipeline.set_source(HttpServerSourceStage(config, bind_address="0.0.0.0", sleep_time=sleep_time))
        else:
            pipeline.set_source(
                HttpClientSourceStage(config, url="http://localhost:8080/api/v1/data", sleep_time=sleep_time))

    @stage
    def print_msg(msg: typing.Any) -> MessageMeta:
        if isinstance(msg, MessageMeta):
            print(f"Received:\n{msg.df}")
        elif isinstance(msg, ControlMessage):
            print(f"Received control message:\n{msg.payload().df}")
            msg = msg.payload()
        else:
            print(f"Received:\n{msg}")
            df_class = get_df_class(config.execution_mode)
            df = df_class({"country": msg.countries})
            msg = MessageMeta(df)

        return msg

    completion_task = {"task_type": "completion", "task_dict": {"input_keys": ["country"], }}
    pipeline.add_stage(print_msg(config))
    pipeline.add_stage(DeserializeStage(config, task_type="llm_engine", task_payload=completion_task))
    pipeline.add_stage(LLMEngineStage(config, engine=_build_engine(raise_error=raise_error)))
    pipeline.add_stage(print_msg(config))
    pipeline.run()


if __name__ == "__main__":
    run_pipeline()
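To check a fix, the hang can be detected automatically. The approach is to start the script above, send one request, and see whether the process exits within a timeout. This is a hedged sketch: repro.py is a hypothetical filename for the script, and the URL, startup delay, and timeout are assumptions rather than values from this report.

```python
import json
import subprocess
import sys
import time
import urllib.request
from typing import Callable


def reproduces_hang(cmd: list[str], trigger: Callable[[], None],
                    startup_s: float, timeout_s: float) -> bool:
    """Start cmd, call trigger once, and return True if the process is still
    running timeout_s seconds later (i.e. the error did not halt it)."""
    proc = subprocess.Popen(cmd)
    try:
        time.sleep(startup_s)  # give the pipeline time to start listening
        trigger()
        proc.wait(timeout=timeout_s)
        return False  # the process exited on its own
    except subprocess.TimeoutExpired:
        return True  # still running after the failing request: hang reproduced
    finally:
        if proc.poll() is None:
            proc.terminate()  # stop it instead of waiting for a Ctrl-C
            proc.wait()


def post_one_request() -> None:
    # Assumed HttpServerSourceStage defaults (port 8080, endpoint /message).
    body = json.dumps([{"country": "France"}]).encode("utf-8")
    req = urllib.request.Request("http://localhost:8080/message", data=body, method="POST",
                                 headers={"Content-Type": "application/json"})
    urllib.request.urlopen(req, timeout=10).close()


def main() -> None:
    # "repro.py" is a hypothetical filename for the script above; the delays are guesses.
    hung = reproduces_hang([sys.executable, "repro.py", "--raise_error"], post_one_request,
                           startup_s=30.0, timeout_s=60.0)
    print("hang reproduced" if hung else "pipeline halted")
```

While the bug is present, main() should print "hang reproduced"; once an exception in an LLMNode halts the pipeline, it should print "pipeline halted".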

Other/Misc.

No response

Code of Conduct

  • I agree to follow Morpheus' Code of Conduct
  • I have searched the open bugs and have found no duplicates for this bug report