add pyproject.toml
Josephrp committed Aug 7, 2024
1 parent 91701fc commit 5110795
Showing 3 changed files with 250 additions and 32 deletions.
@@ -1 +1,4 @@
-python_sources()
+python_sources(
+    name="llama-index-llms-githubllm",
+    sources=["*.py", "*.pyi"],
+)
@@ -1,14 +1,16 @@
import os
import logging
import time
-from typing import Any, Dict, List, Sequence, Generator
+from typing import Any, Dict, List, Sequence, Generator, AsyncGenerator
+import aiohttp
import requests
from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
CompletionResponse,
CompletionResponseAsyncGen,
CompletionResponseGen,
LLMMetadata,
)
@@ -67,8 +69,8 @@ class GithubLLM(CustomLLM):

_rate_limit_reset_time: float = PrivateAttr(default=0)
_request_count: int = PrivateAttr(default=0)
-_max_requests_per_minute: int = PrivateAttr(default=15)  # Adjust based on your tier
-_max_requests_per_day: int = PrivateAttr(default=150)  # Adjust based on your tier
+_max_requests_per_minute: int = PrivateAttr(default=15)
+_max_requests_per_day: int = PrivateAttr(default=150)

SUPPORTED_MODELS = [
"AI21-Jamba-Instruct",
@@ -95,6 +97,29 @@ class GithubLLM(CustomLLM):
"phi-3-small-instruct-8k",
]

MODEL_TOKEN_LIMITS = {
"AI21-Jamba-Instruct": {"input": 72000, "output": 4000},
"cohere-command-r": {"input": 131000, "output": 4000},
"cohere-command-r-plus": {"input": 131000, "output": 4000},
"meta-llama-3-70b-instruct": {"input": 8000, "output": 4000},
"meta-llama-3-8b-instruct": {"input": 8000, "output": 4000},
"meta-llama-3.1-405b-instruct": {"input": 131000, "output": 4000},
"meta-llama-3.1-70b-instruct": {"input": 131000, "output": 4000},
"meta-llama-3.1-8b-instruct": {"input": 131000, "output": 4000},
"mistral-large": {"input": 33000, "output": 4000},
"mistral-large-2407": {"input": 131000, "output": 4000},
"mistral-nemo": {"input": 131000, "output": 4000},
"mistral-small": {"input": 33000, "output": 4000},
"gpt-4o": {"input": 131000, "output": 4000},
"gpt-4o-mini": {"input": 131000, "output": 4000},
"phi-3-medium-instruct-128k": {"input": 131000, "output": 4000},
"phi-3-medium-instruct-4k": {"input": 4000, "output": 4000},
"phi-3-mini-instruct-128k": {"input": 131000, "output": 4000},
"phi-3-mini-instruct-4k": {"input": 4000, "output": 4000},
"phi-3-small-instruct-128k": {"input": 131000, "output": 4000},
"phi-3-small-instruct-8k": {"input": 131000, "output": 4000},
}

@validator("model")
def validate_model(cls, v):
if v.lower() not in [model.lower() for model in cls.SUPPORTED_MODELS]:
@@ -106,9 +131,12 @@ def validate_model(cls, v):
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
model_limits = self.MODEL_TOKEN_LIMITS.get(
self.model, {"input": 4096, "output": 4000}
)
return LLMMetadata(
-context_window=4096,  # Assuming a default context window, adjust as needed
-num_output=self.max_tokens or 256,  # Adjust default as needed
+context_window=model_limits["input"],
+num_output=self.max_tokens or model_limits["output"],
model_name=self.model,
)

@@ -144,36 +172,36 @@ def _call_api(
response.raise_for_status()
return response

def _prepare_messages(
self, messages: Sequence[ChatMessage]
) -> List[Dict[str, str]]:
"""Prepare messages for API call, including system prompt if present."""
formatted_messages = []
if self.system_prompt:
formatted_messages.append({"role": "system", "content": self.system_prompt})
formatted_messages.extend(
[{"role": m.role, "content": m.content} for m in messages]
)
return formatted_messages

@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
"""Generate a completion."""
-messages = [
-{"role": "system", "content": self.system_prompt},
-{"role": "user", "content": prompt},
-]
+messages = self._prepare_messages([ChatMessage(role="user", content=prompt)])
response_content = self._call_llm(messages, **kwargs)
return CompletionResponse(text=response_content)

@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
"""Stream a completion."""
-messages = [
-{"role": "system", "content": self.system_prompt},
-{"role": "user", "content": prompt},
-]
+messages = self._prepare_messages([ChatMessage(role="user", content=prompt)])
for chunk in self._stream_llm(messages, **kwargs):
yield CompletionResponse(text=chunk, delta=chunk)

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
"""Generate a chat response."""
-formatted_messages = [{"role": "system", "content": self.system_prompt}]
-formatted_messages.extend(
-[{"role": m.role, "content": m.content} for m in messages]
-)
+formatted_messages = self._prepare_messages(messages)
response_content = self._call_llm(formatted_messages, **kwargs)
return ChatResponse(
message=ChatMessage(role="assistant", content=response_content)
@@ -184,23 +212,25 @@ def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
"""Stream a chat response."""
-formatted_messages = [{"role": "system", "content": self.system_prompt}]
-formatted_messages.extend(
-[{"role": m.role, "content": m.content} for m in messages]
-)
+formatted_messages = self._prepare_messages(messages)
for chunk in self._stream_llm(formatted_messages, **kwargs):
yield ChatResponse(
message=ChatMessage(role="assistant", content=chunk), delta=chunk
)

def _call_llm(self, messages: List[Dict[str, str]], **kwargs: Any) -> str:
"""Call the LLM with fallback to Azure if rate limited."""
model_limits = self.MODEL_TOKEN_LIMITS.get(
self.model, {"input": 4096, "output": 4000}
)
max_tokens = min(
self.max_tokens or model_limits["output"], model_limits["output"]
)

data = {
"messages": messages,
"model": self.model,
"temperature": self.temperature,
-"max_tokens": self.max_tokens,
+"max_tokens": max_tokens,
**kwargs,
}

@@ -239,12 +269,18 @@ def _call_llm(self, messages: List[Dict[str, str]], **kwargs: Any) -> str:
def _stream_llm(
self, messages: List[Dict[str, str]], **kwargs: Any
) -> Generator[str, None, None]:
"""Stream from the LLM with fallback to Azure if rate limited."""
model_limits = self.MODEL_TOKEN_LIMITS.get(
self.model, {"input": 4096, "output": 4000}
)
max_tokens = min(
self.max_tokens or model_limits["output"], model_limits["output"]
)

data = {
"messages": messages,
"model": self.model,
"temperature": self.temperature,
-"max_tokens": self.max_tokens,
+"max_tokens": max_tokens,
"stream": True,
**kwargs,
}
@@ -290,6 +326,178 @@ def _stream_llm(
else:
raise ValueError("Rate limit reached and Azure fallback is disabled.")

async def _async_call_api(
self,
endpoint_url: str,
headers: Dict[str, str],
data: Dict[str, Any],
) -> Dict[str, Any]:
"""Make an asynchronous, non-streaming API call to either GitHub or Azure."""
async with aiohttp.ClientSession() as session:
async with session.post(
endpoint_url, headers=headers, json=data
) as response:
response.raise_for_status()
return await response.json()

async def _async_stream_api(
self,
endpoint_url: str,
headers: Dict[str, str],
data: Dict[str, Any],
) -> AsyncGenerator[str, None]:
"""Make an asynchronous, streaming API call to either GitHub or Azure."""
async with aiohttp.ClientSession() as session:
async with session.post(
endpoint_url, headers=headers, json=data
) as response:
response.raise_for_status()
async for line in response.content:
if line:
yield line.decode("utf-8")

@llm_completion_callback()
async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
"""Generate an asynchronous completion."""
messages = self._prepare_messages([ChatMessage(role="user", content=prompt)])
response_content = await self._async_call_llm(messages, **kwargs)
return CompletionResponse(text=response_content)

@llm_completion_callback()
async def astream_complete(
self, prompt: str, **kwargs: Any
) -> CompletionResponseAsyncGen:
"""Stream an asynchronous completion."""
messages = self._prepare_messages([ChatMessage(role="user", content=prompt)])
async for chunk in self._async_stream_llm(messages, **kwargs):
yield CompletionResponse(text=chunk, delta=chunk)

@llm_chat_callback()
async def achat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponse:
"""Generate an asynchronous chat response."""
formatted_messages = self._prepare_messages(messages)
response_content = await self._async_call_llm(formatted_messages, **kwargs)
return ChatResponse(
message=ChatMessage(role="assistant", content=response_content)
)

@llm_chat_callback()
async def astream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseAsyncGen:
"""Stream an asynchronous chat response."""
formatted_messages = self._prepare_messages(messages)
async for chunk in self._async_stream_llm(formatted_messages, **kwargs):
yield ChatResponse(
message=ChatMessage(role="assistant", content=chunk), delta=chunk
)

async def _async_call_llm(
self, messages: List[Dict[str, str]], **kwargs: Any
) -> str:
model_limits = self.MODEL_TOKEN_LIMITS.get(
self.model, {"input": 4096, "output": 4000}
)
max_tokens = min(
self.max_tokens or model_limits["output"], model_limits["output"]
)

data = {
"messages": messages,
"model": self.model,
"temperature": self.temperature,
"max_tokens": max_tokens,
**kwargs,
}

if self._check_rate_limit():
try:
github_token = os.environ.get("GITHUB_TOKEN")
if not github_token:
raise ValueError("GITHUB_TOKEN environment variable is not set.")

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {github_token}",
}

response = await self._async_call_api(
self.github_endpoint_url, headers, data
)
self._increment_request_count()
return response["choices"][0]["message"]["content"]
except (aiohttp.ClientError, ValueError) as e:
logger.warning(f"GitHub API call failed: {e!s}. Falling back to Azure.")

if self.use_azure_fallback:
azure_api_key = os.environ.get("AZURE_API_KEY")
if not azure_api_key:
raise ValueError("AZURE_API_KEY environment variable is not set.")

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {azure_api_key}",
}

response = await self._async_call_api(
self.github_endpoint_url, headers, data
)
return response["choices"][0]["message"]["content"]
else:
raise ValueError("Rate limit reached and Azure fallback is disabled.")

async def _async_stream_llm(
self, messages: List[Dict[str, str]], **kwargs: Any
) -> AsyncGenerator[str, None]:
model_limits = self.MODEL_TOKEN_LIMITS.get(
self.model, {"input": 4096, "output": 4000}
)
max_tokens = min(
self.max_tokens or model_limits["output"], model_limits["output"]
)

data = {
"messages": messages,
"model": self.model,
"temperature": self.temperature,
"max_tokens": max_tokens,
"stream": True,
**kwargs,
}

if self._check_rate_limit():
try:
github_token = os.environ.get("GITHUB_TOKEN")
if not github_token:
raise ValueError("GITHUB_TOKEN environment variable is not set.")

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {github_token}",
}

async for chunk in self._async_stream_api(
self.github_endpoint_url, headers, data
):
yield chunk
self._increment_request_count()
return
except (aiohttp.ClientError, ValueError) as e:
logger.warning(f"GitHub API call failed: {e!s}. Falling back to Azure.")

if self.use_azure_fallback:
azure_api_key = os.environ.get("AZURE_API_KEY")
if not azure_api_key:
raise ValueError("AZURE_API_KEY environment variable is not set.")

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {azure_api_key}",
}

async for chunk in self._async_stream_api(
self.github_endpoint_url, headers, data
):
yield chunk
else:
raise ValueError("Rate limit reached and Azure fallback is disabled.")

@classmethod
def class_name(cls) -> str:
"""Get the name of the class."""
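For orientation, here is a minimal usage sketch of the async entry points added in this commit. It is not part of the diff: it assumes the package is importable as llama_index.llms.githubllm (matching the import_path declared in pyproject.toml below), that GithubLLM accepts model and use_azure_fallback as constructor fields, and that a valid GITHUB_TOKEN is set in the environment.

# Hypothetical usage sketch, not part of the commit diff.
# Assumptions: the package installs as llama_index.llms.githubllm, GithubLLM
# exposes `model` and `use_azure_fallback` as fields, and GITHUB_TOKEN is set.
import asyncio

from llama_index.core.base.llms.types import ChatMessage
from llama_index.llms.githubllm import GithubLLM


async def main() -> None:
    llm = GithubLLM(model="gpt-4o-mini", use_azure_fallback=False)

    # Async completion: _async_call_llm caps max_tokens using MODEL_TOKEN_LIMITS
    # for the selected model before calling the GitHub endpoint.
    completion = await llm.acomplete("Explain rate limiting in one sentence.")
    print(completion.text)

    # Async chat: _prepare_messages prepends the system prompt (if configured)
    # before the user turns are sent.
    chat = await llm.achat([ChatMessage(role="user", content="Hello!")])
    print(chat.message.content)


if __name__ == "__main__":
    asyncio.run(main())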
@@ -9,6 +9,13 @@ check-hidden = true
# work through many typos (--write-changes and --interactive will help)
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = false
import_path = "llama_index.llms.githubllm"

[tool.llamahub.class_authors]
GithubLLM = "Josephrp"

[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
@@ -17,7 +24,7 @@ ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
-authors = ["Your Name <you@example.com>"]
+authors = ["Tonic <tonic@tonic-ai.com>"]
description = "llama-index llms githubllm integration"
license = "MIT"
name = "llama-index-llms-githubllm"
