Commit
[chore] add multithreaded chat benchmarking code
yashbonde committed Dec 19, 2024
1 parent 091a33d commit 97009ad
Showing 5 changed files with 280 additions and 2 deletions.
275 changes: 275 additions & 0 deletions benchmarks/threaded_map.py
@@ -0,0 +1,275 @@
# Copyright © 2024- Frello Technology Private Limited
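"""Benchmark two thread-based strategies for bulk LLM chat calls: a hand-rolled
queue/worker pool (bulk_chat) and a ThreadPoolExecutor-based map (bulk_chat_2),
each retrying failed prompts up to a fixed budget."""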

from tuneapi.types import Thread, ModelInterface, human, system
from tuneapi.utils import from_json

import re
import queue
import random
import threading
from tqdm import trange
from typing import Any, Callable, List, Optional
from dataclasses import dataclass

from concurrent.futures import ThreadPoolExecutor, as_completed


@dataclass
class Task:
    index: int
    model: ModelInterface
    prompt: Thread
    retry_count: int = 0


@dataclass
class Result:
    index: int
    data: Any
    success: bool
    error: Optional[Exception] = None

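# Strategy 1: a fixed pool of worker threads consuming from a task queue.
# Failed tasks are requeued with an incremented retry_count; None acts as a
# poison pill to shut the workers down once all results have been collected.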
def bulk_chat(
    model: ModelInterface,
    prompts: List[Thread],
    post_logic: Optional[Callable] = None,
    max_threads: int = 10,
    retry: int = 3,
    pbar=True,
):
    task_channel = queue.Queue()
    result_channel = queue.Queue()

    # Initialize results container
    results = [None for _ in range(len(prompts))]

    def worker():
        while True:
            try:
                task: Task = task_channel.get(timeout=1)
                if task is None:  # Poison pill
                    break

                try:
                    out = task.model.chat(task.prompt)
                    if post_logic:
                        out = post_logic(out)
                    result_channel.put(Result(task.index, out, True))
                except Exception as e:
                    if task.retry_count < retry:
                        # Create new model instance for retry
                        nm = model.__class__(
                            id=model.model_id,
                            base_url=model.base_url,
                            extra_headers=model.extra_headers,
                        )
                        nm.set_api_token(model.api_token)
                        # Increment retry count and requeue
                        task_channel.put(
                            Task(task.index, nm, task.prompt, task.retry_count + 1)
                        )
                    else:
                        # If we've exhausted retries, store the error
                        result_channel.put(Result(task.index, e, False, e))
                finally:
                    task_channel.task_done()
            except queue.Empty:
                continue

    # Create and start worker threads
    workers = []
    for _ in range(max_threads):
        t = threading.Thread(target=worker)
        t.start()
        workers.append(t)

    # Initialize progress bar
    _pbar = trange(len(prompts), desc="Processing", unit=" input") if pbar else None

    # Queue initial tasks, each with its own model instance since API clients
    # may not be safe to share across threads
    for i, p in enumerate(prompts):
        nm = model.__class__(
            id=model.model_id,
            base_url=model.base_url,
            extra_headers=model.extra_headers,
        )
        nm.set_api_token(model.api_token)
        task_channel.put(Task(i, nm, p))

    # Process results
    completed = 0
    while completed < len(prompts):
        try:
            result = result_channel.get(timeout=1)
            results[result.index] = result.data if result.success else result.error
            if _pbar:
                _pbar.update(1)
            completed += 1
            result_channel.task_done()
        except queue.Empty:
            continue

    # Cleanup
    for _ in workers:
        task_channel.put(None)  # Send poison pills
    for w in workers:
        w.join()

    if _pbar:
        _pbar.close()

    return results


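# Build 100 prompts that ask for an answer inside <json>...</json> tags, so the
# post-processing step below has something structured to extract (and to fail
# on, when the model misses the tags).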
prompts = []
for x in range(100):
    prompts.append(
        Thread(
            system(
                """## Response schema
Respond with the following schema; **make sure to send the <json> and </json> tags**.
```
<json>
{
    "code": "...",
}
</json>
```
"""
            ),
            human(
                f"What is the value of 10 ^ {max(x, 10)}? Write down the answer in the Indian number system, given in the code tag."
            ),
        )
    )


def get_tagged_section(tag: str, input_str: str):
    # Randomly skip parsing ~50% of the time so that post_logic fails and the
    # retry paths in the bulk runners get exercised.
    if random.random() > 0.5:
        html_pattern = re.compile("<" + tag + ">(.*?)</" + tag + ">", re.DOTALL)
        match = html_pattern.search(input_str)
        if match:
            return match.group(1)

        md_pattern = re.compile("```" + tag + "(.*?)```", re.DOTALL)
        match = md_pattern.search(input_str)
        if match:
            return match.group(1)
    return None


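# post_logic extracts the "code" field; it should raise when get_tagged_section
# returns None or the JSON is malformed, which is what marks a task as failed
# and triggers the retry machinery in both implementations.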
post_logic = lambda out: from_json(get_tagged_section("json", out))["code"]

from tuneapi.apis import Openai

from time import time

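# Benchmark strategy 1: the queue/worker implementation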
st = time()
out = bulk_chat(
    Openai(),
    prompts,
    post_logic=post_logic,
    max_threads=5,
    pbar=True,
    retry=3,
)
print(out)
print(len(out))
print(len([x for x in out if x is None]))

print(f"Endtime: {time() - st:0.4f}s")

print("\n\n\n")


from uuid import uuid4
from typing import Generator


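# Strategy 2: a ThreadPoolExecutor-based map. Tasks are submitted in waves;
# each wave resubmits only the tasks that failed, until none fail or the retry
# budget is exhausted.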
def bulk_chat_2(
    model: ModelInterface,
    prompts: List[Thread],
    post_logic: Optional[Callable] = None,
    max_threads: int = 10,
    retry: int = 3,
    pbar=True,
):
    def _chat(model: ModelInterface, prompt: Thread):
        out = model.chat(prompt)
        if post_logic:
            return post_logic(out)  # The mapped function
        return out

    # create all the inputs, each with its own model instance
    retry = int(retry)  # so False becomes 0 and True becomes 1
    inputs = []
    for p in prompts:
        nm = model.__class__(
            id=model.model_id,
            base_url=model.base_url,
            extra_headers=model.extra_headers,
        )
        nm.set_api_token(model.api_token)
        inputs.append((nm, p))

    # run the executor
    _name = str(uuid4())
    if isinstance(inputs, Generator):
        inputs = list(inputs)

    results = [None for _ in range(len(inputs))]
    _pbar = trange(len(inputs), desc="Processing", unit=" input") if pbar else None
    with ThreadPoolExecutor(max_workers=max_threads, thread_name_prefix=_name) as exe:
        _fn = lambda x: _chat(*x)
        loop_cntr = 0
        done = False
        inputs = [(i, x) for i, x in enumerate(inputs)]

        # loop until every task has succeeded or the retry budget is spent
        while not done:
            failed = []
            if _pbar:
                _pbar.set_description(f"Starting master loop #{loop_cntr:02d}")
            futures = {exe.submit(_fn, x): (i, x) for (i, x) in inputs}
            for fut in as_completed(futures):
                i, x = futures[fut]  # indexes
                try:
                    res = fut.result()
                    if _pbar:
                        _pbar.update(1)
                    results[i] = res
                except Exception:
                    failed.append((i, x))

            # override inputs for the next loop: only failed tasks are resubmitted
            inputs = failed

            # the done flag
            loop_cntr += 1
            done = len(failed) == 0 or loop_cntr > retry

    if _pbar:
        _pbar.close()
    return results


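# Benchmark strategy 2: the ThreadPoolExecutor implementation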
st = time()
out = bulk_chat_2(
    Openai(),
    prompts,
    post_logic=post_logic,
    max_threads=5,
    pbar=True,
    retry=3,
)
print(out)
print(len(out))
print(len([x for x in out if x is None]))

print(f"Endtime: {time() - st:0.4f}s")
3 changes: 2 additions & 1 deletion docs/changelog.rst
@@ -2,7 +2,7 @@ Changelog
=========

This package is already used in production at Tune AI, please do not wait for release ``1.x.x`` for stability, or expect
-to reach ``1.0.0``. We do not follow the general rules of semantic versioning, and there can be breaking changes between
+to reach ``1.0.0``. We **do not follow the general rules** of semantic versioning, and there can be breaking changes between
minor versions.

All relevant steps to be taken will be mentioned here.
Expand All @@ -12,6 +12,7 @@ All relevant steps to be taken will be mentioned here.

- ``distributed_chat`` functionality in ``tuneapi.apis.turbo`` support. In all APIs search for ``model.distributed_chat()``
  method. This enables **fault tolerant LLM API calls**; see the sketch after this list.
- Moved ``tuneapi.types.experimental`` to ``tuneapi.types.evals``
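
A minimal usage sketch for the ``distributed_chat`` entry above, assuming it accepts a list of ``Thread`` objects the way ``bulk_chat`` in this commit does (the actual signature is not shown in this diff):

```python
from tuneapi.apis import Openai
from tuneapi.types import Thread, human

model = Openai()
threads = [Thread(human(f"What is {i} squared?")) for i in range(10)]

# Hypothetical call: per the changelog, this runs the chats with
# fault-tolerant (retrying) API calls and returns one result per thread.
results = model.distributed_chat(threads)
```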

0.5.13
------
2 changes: 2 additions & 0 deletions tests/test_tree.py
@@ -1,3 +1,5 @@
# Copyright © 2024- Frello Technology Private Limited

import tuneapi.types as tt

from unittest import TestCase, main as ut_main
2 changes: 1 addition & 1 deletion tuneapi/types/__init__.py
@@ -14,6 +14,6 @@
    function_resp,
)

-from tuneapi.types.experimental import (
+from tuneapi.types.evals import (
    Evals,
)
tuneapi/types/experimental.py → tuneapi/types/evals.py
File renamed without changes.
