Skip to content

Commit

Permalink
Add nvidia api (#65)
Browse files Browse the repository at this point in the history
Signed-off-by: Igor Gitman <[email protected]>
  • Loading branch information
Kipok authored Jul 17, 2024
1 parent 9b7f9d5 commit 8f1bd2c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ prompt_template: |-
<user_start>{user}<user_end>
<assistant_start>{generation}
stop_phrases: []
stop_phrases: [] # automatically stops on turn token
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
few_shot_examples:
template: ""

system: |-
You're an expert mathematician. Your goal is to solve the math problem below.
To achieve this you always follow these steps:
Expand All @@ -19,4 +16,4 @@ prompt_template: |-
<user_start>{user}<user_end>
<assistant_start>{generation}
stop_phrases: []
stop_phrases: [] # automatically stops on turn token
46 changes: 34 additions & 12 deletions nemo_skills/inference/server/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ def __init__(
):
self.server_host = host
self.server_port = port
self.ssh_server = os.getenv("NEMO_SKILLS_SSH_SERVER", ssh_server)
self.ssh_key_path = os.getenv("NEMO_SKILLS_SSH_KEY_PATH", ssh_key_path)
if ssh_server is None:
self.ssh_server = os.getenv("NEMO_SKILLS_SSH_SERVER")
if ssh_key_path is None:
self.ssh_key_path = os.getenv("NEMO_SKILLS_SSH_KEY_PATH")

if self.ssh_server and self.ssh_key_path:
import sshtunnel_requests
Expand Down Expand Up @@ -209,22 +211,33 @@ def generate(
return outputs


# TODO: this is broken
class OpenAIModel(BaseModel):
def __init__(
self,
model,
model=None,
base_url=None,
api_key=None,
**kwargs,
):
super().__init__(**kwargs)
from openai import OpenAI

if model is None:
model = os.getenv("NEMO_SKILLS_OPENAI_MODEL")
if model is None:
raise ValueError("model argument is required for OpenAI model.")

if base_url is None:
base_url = os.getenv("NEMO_SKILLS_OPENAI_BASE_URL")

if api_key is None:
api_key = os.getenv("OPENAI_API_KEY", api_key)
if base_url is not None and 'api.nvidia.com' in base_url:
api_key = os.getenv("NVIDIA_API_KEY", api_key)
else:
api_key = os.getenv("OPENAI_API_KEY", api_key)

self.model = model
self.client = OpenAI(api_key=api_key)
self.client = OpenAI(api_key=api_key, base_url=base_url)

def generate(
self,
Expand All @@ -241,7 +254,7 @@ def generate(
if stop_phrases is None:
stop_phrases = []
if top_k != 0:
raise ValueError("`top_k` is not supported by OpenAI, please set it to default value `0`.")
raise ValueError("`top_k` is not supported by OpenAI API, please set it to default value `0`.")

outputs = []
for prompt in prompts:
Expand Down Expand Up @@ -293,13 +306,22 @@ def _parse_prompt(self, prompt: str) -> dict:
system_pattern = re.compile(r"<system_start>(.*?)<system_end>", re.DOTALL)
user_pattern = re.compile(r"<user_start>(.*?)<user_end>", re.DOTALL)
generation_pattern = re.compile(r"<assistant_start>(.*)", re.DOTALL)
try:
system_message = system_pattern.search(prompt).group(1)
except AttributeError:
system_message = ""
try:
user_message = user_pattern.search(prompt).group(1)
except AttributeError:
user_message = prompt
messages = [
{"role": "system", "content": system_pattern.search(prompt).group(1)},
{"role": "user", "content": user_pattern.search(prompt).group(1)},
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]
generation_msg = generation_pattern.search(prompt).group(1)
if generation_msg:
messages.append({"role": "assistant", "content": generation_msg})
try:
messages.append({"role": "assistant", "content": generation_pattern.search(prompt).group(1)})
except AttributeError:
pass
return messages


Expand Down

0 comments on commit 8f1bd2c

Please sign in to comment.