From 8f1bd2c54e8bfcf1b6b2c66878b5635c90f12b24 Mon Sep 17 00:00:00 2001 From: Igor Gitman Date: Wed, 17 Jul 2024 14:18:12 -0700 Subject: [PATCH] Add nvidia api (#65) Signed-off-by: Igor Gitman --- .../prompt/openai/{chat.yaml => fewshot.yaml} | 2 +- .../{chat_zeroshot.yaml => zeroshot.yaml} | 5 +- nemo_skills/inference/server/model.py | 46 ++++++++++++++----- 3 files changed, 36 insertions(+), 17 deletions(-) rename nemo_skills/inference/prompt/openai/{chat.yaml => fewshot.yaml} (93%) rename nemo_skills/inference/prompt/openai/{chat_zeroshot.yaml => zeroshot.yaml} (91%) diff --git a/nemo_skills/inference/prompt/openai/chat.yaml b/nemo_skills/inference/prompt/openai/fewshot.yaml similarity index 93% rename from nemo_skills/inference/prompt/openai/chat.yaml rename to nemo_skills/inference/prompt/openai/fewshot.yaml index 3c0e8ee0a..382460267 100644 --- a/nemo_skills/inference/prompt/openai/chat.yaml +++ b/nemo_skills/inference/prompt/openai/fewshot.yaml @@ -21,4 +21,4 @@ prompt_template: |- {user} {generation} -stop_phrases: [] +stop_phrases: [] # automatically stops on turn token diff --git a/nemo_skills/inference/prompt/openai/chat_zeroshot.yaml b/nemo_skills/inference/prompt/openai/zeroshot.yaml similarity index 91% rename from nemo_skills/inference/prompt/openai/chat_zeroshot.yaml rename to nemo_skills/inference/prompt/openai/zeroshot.yaml index cc9e0e366..1706d11a6 100644 --- a/nemo_skills/inference/prompt/openai/chat_zeroshot.yaml +++ b/nemo_skills/inference/prompt/openai/zeroshot.yaml @@ -1,6 +1,3 @@ -few_shot_examples: - template: "" - system: |- You're an expert mathematician. Your goal is to solve the math problem below. To achieve this you always follow these steps: @@ -19,4 +16,4 @@ prompt_template: |- {user} {generation} -stop_phrases: [] +stop_phrases: [] # automatically stops on turn token diff --git a/nemo_skills/inference/server/model.py b/nemo_skills/inference/server/model.py index b67f968c2..5c8a36ebf 100644 --- a/nemo_skills/inference/server/model.py +++ b/nemo_skills/inference/server/model.py @@ -72,8 +72,10 @@ def __init__( ): self.server_host = host self.server_port = port - self.ssh_server = os.getenv("NEMO_SKILLS_SSH_SERVER", ssh_server) - self.ssh_key_path = os.getenv("NEMO_SKILLS_SSH_KEY_PATH", ssh_key_path) + if ssh_server is None: + self.ssh_server = os.getenv("NEMO_SKILLS_SSH_SERVER") + if ssh_key_path is None: + self.ssh_key_path = os.getenv("NEMO_SKILLS_SSH_KEY_PATH") if self.ssh_server and self.ssh_key_path: import sshtunnel_requests @@ -209,22 +211,33 @@ def generate( return outputs -# TODO: this is broken class OpenAIModel(BaseModel): def __init__( self, - model, + model=None, + base_url=None, api_key=None, **kwargs, ): super().__init__(**kwargs) from openai import OpenAI + if model is None: + model = os.getenv("NEMO_SKILLS_OPENAI_MODEL") + if model is None: + raise ValueError("model argument is required for OpenAI model.") + + if base_url is None: + base_url = os.getenv("NEMO_SKILLS_OPENAI_BASE_URL") + if api_key is None: - api_key = os.getenv("OPENAI_API_KEY", api_key) + if base_url is not None and 'api.nvidia.com' in base_url: + api_key = os.getenv("NVIDIA_API_KEY", api_key) + else: + api_key = os.getenv("OPENAI_API_KEY", api_key) self.model = model - self.client = OpenAI(api_key=api_key) + self.client = OpenAI(api_key=api_key, base_url=base_url) def generate( self, @@ -241,7 +254,7 @@ def generate( if stop_phrases is None: stop_phrases = [] if top_k != 0: - raise ValueError("`top_k` is not supported by OpenAI, please set it to default value `0`.") + raise ValueError("`top_k` is not supported by OpenAI API, please set it to default value `0`.") outputs = [] for prompt in prompts: @@ -293,13 +306,22 @@ def _parse_prompt(self, prompt: str) -> dict: system_pattern = re.compile(r"(.*?)", re.DOTALL) user_pattern = re.compile(r"(.*?)", re.DOTALL) generation_pattern = re.compile(r"(.*)", re.DOTALL) + try: + system_message = system_pattern.search(prompt).group(1) + except AttributeError: + system_message = "" + try: + user_message = user_pattern.search(prompt).group(1) + except AttributeError: + user_message = prompt messages = [ - {"role": "system", "content": system_pattern.search(prompt).group(1)}, - {"role": "user", "content": user_pattern.search(prompt).group(1)}, + {"role": "system", "content": system_message}, + {"role": "user", "content": user_message}, ] - generation_msg = generation_pattern.search(prompt).group(1) - if generation_msg: - messages.append({"role": "assistant", "content": generation_msg}) + try: + messages.append({"role": "assistant", "content": generation_pattern.search(prompt).group(1)}) + except AttributeError: + pass return messages