
Commit

Merge branch 'main' into nemotron-fix
shtoshni committed Jul 15, 2024
2 parents 8744d87 + 6c19a72 commit 0008873
Showing 6 changed files with 92 additions and 37 deletions.
24 changes: 24 additions & 0 deletions nemo_skills/inference/prompt/openai/chat.yaml
@@ -0,0 +1,24 @@
few_shot_examples:
template: "Question:\n{question}{context}\n\nMy solution:\n{generation}\n\n\n\n\n\n"

system: |-
You're an expert mathematician. Your goal is to solve the math problem below.
To achieve this, you always follow these steps:
1. Start by carefully analyzing the given problem.
2. Write a DETAILED step-by-step solution.
3. Put the final answer inside \boxed{{}}.
user: |-
Here are some examples of questions and solutions followed by a new question that you need to solve.
{examples}Question:
{question}{context}
# <..._start> and <..._end> are special tokens that are not directly visible to the model.
# They are used to parse the prompt into parts in our inference pipeline.
prompt_template: |-
<system_start>{system}<system_end>
<user_start>{user}<user_end>
<assistant_start>{generation}
stop_phrases: []
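For reference, a minimal sketch of how this template renders into the flat prompt string (hypothetical values; the pipeline fills {system}, {user}, and {generation} itself, and the doubled braces in \boxed{{}} protect the literal braces from str.format):

template = (
    "<system_start>{system}<system_end>\n"
    "<user_start>{user}<user_end>\n"
    "<assistant_start>{generation}"
)
prompt = template.format(
    system="You're an expert mathematician.",
    user="Question:\nWhat is 2 + 2?",
    generation="",  # empty until the model starts generating
)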
22 changes: 22 additions & 0 deletions nemo_skills/inference/prompt/openai/chat_zeroshot.yaml
@@ -0,0 +1,22 @@
few_shot_examples:
template: ""

system: |-
You're an expert mathematician. Your goal is to solve the math problem below.
To achieve this, you always follow these steps:
1. Start by carefully analyzing the given problem.
2. Write a DETAILED step-by-step solution.
3. Put the final answer inside \boxed{{}}.
user: |-
Question:
{question}{context}
# <..._start> and <..._end> are special tokens that are not directly visible to the model.
# They are used to parse the prompt into parts in our inference pipeline.
prompt_template: |-
<system_start>{system}<system_end>
<user_start>{user}<user_end>
<assistant_start>{generation}
stop_phrases: []
12 changes: 0 additions & 12 deletions nemo_skills/inference/prompt/utils.py
@@ -186,18 +186,6 @@ def build_user_message(self, input_dict: Dict[str, str]) -> str:
user = self.config.user.format(examples=examples, context=context, **input_dict)
return user

def build_structured(self, input_dict: Dict[str, str]) -> List[Dict[str, str]]:
"""Builds a structured representation of the prompt.
The "generation" in the input_dict is a special key that will be
appended to the structured prompt as an assistant message.
"""
structured_prompt = [{"role": "system", "content": self.config.system}] if self.config.system else []
structured_prompt.append({"role": "user", "content": self.build_user_message(input_dict)})
if input_dict.get('generation'):
structured_prompt.append({"role": "assistant", "content": input_dict.get('generation')})
return structured_prompt

def build_string(self, input_dict: Dict[str, str]) -> str:
"""Returns the complete prompt string representation."""
generation = input_dict.get("generation", "")
17 changes: 14 additions & 3 deletions nemo_skills/inference/server/code_execution_model.py
@@ -23,7 +23,7 @@
from nemo_skills.code_execution import CODE_OUTPUT_SEPARATORS, CODE_SEPARATORS, extract_code_to_execute
from nemo_skills.code_execution.sandbox import Sandbox
from nemo_skills.inference.prompt.utils import Prompt
from nemo_skills.inference.server.model import BaseModel, NemoModel, get_model, models, postprocess_output
from nemo_skills.inference.server.model import BaseModel, NemoModel, OpenAIModel, get_model, models, postprocess_output
from nemo_skills.utils import nested_dataclass, python_doc_to_cmd_help

LOG = logging.getLogger(__name__)
@@ -109,7 +109,7 @@ def generate(
while len(remaining_ids) > 0:
num_executions += 1
request["prompts"] = [new_outputs[idx]['prompt'] for idx in remaining_ids]
outputs = [output['generation'] for output in self.model.generate(**request)]
outputs = [self._handle_stop_words(output['generation']) for output in self.model.generate(**request)]
new_ids = []
# checking if any of the outputs need code execution and submitting requests in parallel
futures = [None] * len(prompts)
@@ -180,7 +180,7 @@ def _recover_from_error(self, request, new_output, executor):
results = [None] * self.config.error_recovery.recovery_attempts
for rs in range(self.config.error_recovery.recovery_attempts):
recovery_request['random_seed'] = rs
output = self.model.generate(**recovery_request)[0]['generation']
output = self._handle_stop_words(self.model.generate(**recovery_request)[0]['generation'])
outputs.append(output)
if output.strip().endswith(CODE_SEPARATORS[-1]):
futures[rs] = executor.submit(
@@ -220,6 +220,17 @@ def _recover_from_error(self, request, new_output, executor):

return most_common

def _handle_stop_words(self, output: str):
"""
The OpenAI chat API removes the stop word from the output, so we need to add it
back to enable code execution.
"""
if not isinstance(self.model, OpenAIModel):
return output
if output.find(CODE_SEPARATORS[0]) > output.find(CODE_SEPARATORS[-1]):
return output + CODE_SEPARATORS[-1]
return output


def server_params():
"""Returns server documentation (to include in cmd help)."""
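A standalone sketch of the repair that _handle_stop_words performs, assuming CODE_SEPARATORS is a pair of code-block tags such as ('<llm-code>', '</llm-code>') (the actual values are defined in nemo_skills.code_execution):

CODE_SEPARATORS = ('<llm-code>', '</llm-code>')  # assumed values, for illustration

def restore_stop_word(output: str) -> str:
    # find() returns -1 for a missing tag, so this triggers when the close
    # tag is absent or when an open tag follows the first close tag; either
    # way the last code block is unterminated and the close tag is re-added.
    if output.find(CODE_SEPARATORS[0]) > output.find(CODE_SEPARATORS[-1]):
        return output + CODE_SEPARATORS[-1]
    return output

truncated = "Let me compute this.\n<llm-code>\nprint(2 + 2)\n"
assert restore_stop_word(truncated).endswith('</llm-code>')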
27 changes: 20 additions & 7 deletions nemo_skills/inference/server/model.py
@@ -244,10 +244,9 @@ def generate(
raise ValueError("`top_k` is not supported by OpenAI, please set it to default value `0`.")

outputs = []
for input_dict in input_dicts:
for prompt in prompts:
response = self._send_request(
prompt=prompt,
input_dict=input_dict,
tokens_to_generate=tokens_to_generate,
temperature=temperature,
top_p=top_p,
@@ -272,7 +271,7 @@ def _send_request(
random_seed: int,
stop_phrases: list[str],
) -> str:
messages = prompt.build_structured(input_dict)
messages = self._parse_prompt(prompt)
response = self.client.chat.completions.create(
model=self.model,
temperature=temperature,
@@ -284,11 +283,25 @@
messages=messages,
).choices[0]
output = response.message.content
# adding back stop words
if response.finish_reason == "stop":
output += response.stop_reason
return output

def _parse_prompt(self, prompt: str) -> dict:
"""
The OpenAI chat API requires structured input, so we need to parse the prompt
string into a structured list of messages.
"""
system_pattern = re.compile(r"<system_start>(.*?)<system_end>", re.DOTALL)
user_pattern = re.compile(r"<user_start>(.*?)<user_end>", re.DOTALL)
generation_pattern = re.compile(r"<assistant_start>(.*)", re.DOTALL)
messages = [
{"role": "system", "content": system_pattern.search(prompt).group(1)},
{"role": "user", "content": user_pattern.search(prompt).group(1)},
]
generation_msg = generation_pattern.search(prompt).group(1)
if generation_msg:
messages.append({"role": "assistant", "content": generation_msg})
return messages
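For example, a prompt rendered from the openai/chat templates above parses into messages like this (hand-built illustrative string):

prompt = (
    "<system_start>You're an expert mathematician.<system_end>\n"
    "<user_start>Question:\nWhat is 2 + 2?<user_end>\n"
    "<assistant_start>"
)
# _parse_prompt(prompt) returns:
# [{"role": "system", "content": "You're an expert mathematician."},
#  {"role": "user", "content": "Question:\nWhat is 2 + 2?"}]
# No assistant message is appended because the text after <assistant_start> is empty.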


class VLLMModel(BaseModel):
def __init__(self, **kwargs):
@@ -430,7 +443,7 @@ def get_model_name_from_server(self) -> str:
models = {
'tensorrt_llm': TensorRTLLMModel,
'nemo': NemoModel,
# 'openai': OpenAIModel,
'openai': OpenAIModel,
'vllm': VLLMModel,
}

27 changes: 12 additions & 15 deletions visualization/utils/decoration.py
@@ -34,7 +34,7 @@ def starts_with_tag_func_templ(text: str, index: int):
elif '{' not in tag:
returning_index = index + len(tag)
else:
returning_index = text.find('}', index) % (len(text) + 1) + 1
returning_index = text.find('}', index) % (len(text) + 1)

return is_starts_with_tag, returning_index
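The fixed index arithmetic, spelled out: str.find returns -1 on a miss, and taking the result modulo len(text) + 1 maps that to len(text), one position past the last character, where the old trailing + 1 overshot by one:

text = "no closing brace"
assert text.find('}', 0) == -1
assert text.find('}', 0) % (len(text) + 1) == len(text)
# A hit is unchanged, since any found index p satisfies p < len(text) + 1:
assert "a}b".find('}', 0) % (len("a}b") + 1) == 1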

Expand Down Expand Up @@ -108,20 +108,17 @@ def proccess_plain_text(text: str) -> str:


def preprocess_latex(text: str, escape: bool = True) -> str:
text = (
'\n'
+ text.replace('\\[', '\n$$\n')
.replace('\\]', '\n$$\n')
.replace('\\(', ' $')
.replace('\\)', '$ ')
.replace('=', ' = ')
.replace('+', ' + ')
.replace('-', ' - ')
.replace('*', ' * ')
.replace('/', ' / ')
.replace(' ', ' ')
+ '\n'
)
text = '\n' + text.replace('\\[', '\n$$\n').replace('\\]', '\n$$\n').replace('\\(', ' $').replace('\\)', '$ ')

right_side_operations = ['-', '=', '+', '*', '/']
left_side_operations = ['=', '+', '*', '/']
for op in right_side_operations:
text = text.replace(op + '$', op + ' $')

for op in left_side_operations:
text = text.replace('$' + op, '$ ' + op)

text += '\n'
index = 1
texts = []
start_plain_text_index = -1
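A quick check of the targeted padding: unlike the removed version, which spaced every arithmetic operator in the text, the new loops only pad operators that touch a $ delimiter:

s = '$x$=2+$y$'
for op in ['-', '=', '+', '*', '/']:
    s = s.replace(op + '$', op + ' $')
for op in ['=', '+', '*', '/']:
    s = s.replace('$' + op, '$ ' + op)
assert s == '$x$ =2+ $y$'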
