diff --git a/docs/docs/examples/llm/mymagic.ipynb b/docs/docs/examples/llm/mymagic.ipynb
index 81a187cbf4abc..83fa89fe6fdac 100644
--- a/docs/docs/examples/llm/mymagic.ipynb
+++ b/docs/docs/examples/llm/mymagic.ipynb
@@ -69,6 +69,7 @@
     "    region=\"your-bucket-region\",\n",
     "    return_output=False, # Whether you want MyMagic API to return the output json\n",
     "    input_json_file=None, # name of the input file (stored on the bucket)\n",
+    "    list_inputs=None, # Option to provide inputs as a list in case of a small batch\n",
     "    structured_output=None, # json schema of the output\n",
     ")"
    ]
@@ -77,7 +78,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Note: if return_output is set True above, max_tokens should be set to atleast 100 "
+    "Note: if return_output is set to True above, max_tokens should be set to at least 100"
    ]
   },
   {
@@ -129,7 +130,7 @@
    "async def main():\n",
    "    response = await llm.acomplete(\n",
    "        question=\"your-question\",\n",
-    "        model=\"chhoose-model\", # currently we support mistral7b, llama7b, mixtral8x7,codellama70b, llama70b, more to come...\n",
+    "        model=\"choose-model\", # supported models are constantly updated and listed at docs.mymagic.ai\n",
    "        max_tokens=5, # number of tokens to generate, default is 10\n",
    "    )\n",
    "\n",
diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
index 8099b2f48cdee..5f72908daeef7 100644
--- a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
@@ -1,5 +1,5 @@
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 import httpx
 import asyncio
 import requests
@@ -23,79 +23,105 @@ class MyMagicAI(LLM):
     `pip install llama-index-llms-mymagic`

     ```python
-    from llama_index.llms.mistralai import MistralAI
-
-    llm = MistralAI(model="mistral7b", api_key="")
+    from llama_index.llms.mymagic import MyMagicAI
+
+    llm = MyMagicAI(
+        api_key="your-api-key",
+        storage_provider="s3", # s3, gcs
+        bucket_name="your-bucket-name",
+        list_inputs="your list of inputs if you choose to pass them directly",
+        session="your-session-name", # the bucket folder containing the files that batch inference will run on
+        role_arn="your-role-arn",
+        system_prompt="your-system-prompt",
+        region="your-bucket-region",
+        return_output=False, # Whether you want MyMagic API to return the output json
+        input_json_file=None, # name of the input file (stored on the bucket)
+        structured_output=None, # json schema of the output
+    )

-    resp = llm.complete("Paul Graham is ")
+    resp = llm.complete(
+        question="your-question",
+        model="choose-model", # check the supported models at docs.mymagic.ai
+        max_tokens=5, # number of tokens to generate, default is 10
+    )
     print(resp)
     ```
     """

-    base_url_template: str = "https://{model}.mymagic.ai"
+    base_url_template: str = "https://fastapi.mymagic.ai"
+    completion_url: str = f"{base_url_template}/v1/completions"
+    status_url: str = f"{base_url_template}/get_result"
+
     api_key: str = None
-    model: str = Field(default="mistral7b", description="The MyMagicAI model to use.")
-    max_tokens: int = Field(
-        default=10, description="The maximum number of tokens to generate."
+    list_inputs: Optional[List[str]] = Field(
+        None,
+        description="A list of inputs provided directly to the model, as an alternative to reading them from a storage bucket.",
     )
-    question = Field(default="", description="The user question.")
     storage_provider: str = Field(
-        default="gcs", description="The storage provider to use."
+        default=None, description="The storage provider to use."
     )
     bucket_name: str = Field(
-        default="your-bucket-name",
+        default=None,
         description="The bucket name where the data is stored.",
     )
     session: str = Field(
-        default="test-session",
+        default=None,
         description="The session to use. This is a subfolder in the bucket where your data is located.",
     )
     role_arn: Optional[str] = Field(
         None, description="ARN for role assumption in AWS S3."
     )
-    system_prompt: str = Field(
+    system_prompt: Optional[str] = Field(
         default="Answer the question based only on the given content. Do not give explanations or examples. Do not continue generating more text after the answer.",
         description="The system prompt to use.",
     )
-    question_data: Dict[str, Any] = Field(
-        default_factory=dict, description="The data to send to the MyMagicAI API."
-    )
     region: Optional[str] = Field(
         "eu-west-2", description="The region the bucket is in. Only used for AWS S3."
     )
-    return_output: Optional[bool] = Field(
-        False, description="Whether MyMagic API should return the output json"
-    )
-    input_json_file: Optional[str] = None
+    input_json_file: Optional[str] = Field(
+        None, description="Name of a single json file (stored on the bucket) to read the input from."
+    )
     structured_output: Optional[Dict[str, Any]] = Field(
         None, description="User-defined structure for the response output"
     )
+    model: str = Field(default="mixtral8x7", description="The MyMagicAI model to use.")
+    max_tokens: int = Field(
+        default=10, description="The maximum number of tokens to generate."
+    )
+    question: str = Field(default="", description="The user question.")
+    question_data: Dict[str, Any] = Field(
+        default_factory=dict, description="The data to send to the MyMagicAI API."
+    )
+    return_output: Optional[bool] = Field(
+        False, description="Whether MyMagic API should return the output json"
+    )

     def __init__(
         self,
         api_key: str,
-        storage_provider: str,
-        bucket_name: str,
-        session: str,
-        system_prompt: Optional[str],
-        role_arn: Optional[str] = None,
-        region: Optional[str] = None,
-        return_output: Optional[bool] = False,
+        storage_provider: Optional[str] = None,
         input_json_file: Optional[str] = None,
         structured_output: Optional[Dict[str, Any]] = None,
+        return_output: Optional[bool] = False,
+        list_inputs: Optional[List[str]] = None,
+        role_arn: Optional[str] = None,
+        region: Optional[str] = "eu-west-2",
+        session: Optional[str] = None,
+        bucket_name: Optional[str] = None,
+        system_prompt: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.return_output = return_output
+        self.api_key = api_key
         self.question_data = {
+            "list_inputs": list_inputs,
             "storage_provider": storage_provider,
             "bucket_name": bucket_name,
-            "personal_access_token": api_key,
             "session": session,
-            "max_tokens": self.max_tokens,
             "role_arn": role_arn,
             "system_prompt": system_prompt,
             "region": region,
@@ -108,38 +134,27 @@ def __init__(
     def class_name(cls) -> str:
         return "MyMagicAI"

-    def _construct_url(self, model: str) -> str:
-        """Constructs the API endpoint URL based on the specified model."""
-        return self.base_url_template.format(model=model)
-
     async def _submit_question(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
         timeout_config = httpx.Timeout(600.0, connect=60.0)
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
         async with httpx.AsyncClient(timeout=timeout_config) as client:
-            url = f"{self._construct_url(self.model)}/submit_question"
-            resp = await client.post(url, json=question_data)
+            resp = await client.post(
+                self.completion_url,
+                json=question_data,
+                headers=headers,
+            )
             resp.raise_for_status()
             return resp.json()

-    def _submit_question_sync(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Submits a question to the model synchronously."""
-        url = f"{self._construct_url(self.model)}/submit_question"
-        resp = requests.post(url, json=question_data)
-        resp.raise_for_status()
-        return resp.json()
-
-    def _get_result_sync(self, task_id: str) -> Dict[str, Any]:
-        """Polls for the result of a task synchronously."""
-        url = f"{self._construct_url(self.model)}/get_result/{task_id}"
-        response = requests.get(url)
-        response.raise_for_status()
-        return response.json()
-
     async def _get_result(self, task_id: str) -> Dict[str, Any]:
-        async with httpx.AsyncClient() as client:
-            resp = await client.get(
-                f"{self._construct_url(self.model)}/get_result/{task_id}"
-            )
+        url = f"{self.status_url}/{task_id}"
+        timeout_config = httpx.Timeout(600.0, connect=60.0)
+        async with httpx.AsyncClient(timeout=timeout_config) as client:
+            resp = await client.get(url)
             resp.raise_for_status()
             return resp.json()

@@ -151,11 +166,10 @@ async def acomplete(
         poll_interval: float = 1.0,
     ) -> CompletionResponse:
         self.question_data["question"] = question
-        self.model = self.question_data["model"] = model or self.model
+        self.question_data["model"] = model or self.model
         self.max_tokens = self.question_data["max_tokens"] = (
             max_tokens or self.max_tokens
         )
-
         task_response = await self._submit_question(self.question_data)

         if self.return_output:
@@ -168,6 +182,27 @@ async def acomplete(
                 return result
             await asyncio.sleep(poll_interval)

+    def _submit_question_sync(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
"""Submits a question to the model synchronously.""" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + resp = requests.post( + self.completion_url, + json=question_data, + headers=headers, + ) + resp.raise_for_status() + return resp.json() + + def _get_result_sync(self, task_id: str) -> Dict[str, Any]: + """Polls for the result of a task synchronously.""" + url = f"{self.status_url}/{task_id}" + response = requests.get(url, timeout=600.0) + response.raise_for_status() + return response.json() + def complete( self, question: str, @@ -176,11 +211,10 @@ def complete( poll_interval: float = 1.0, ) -> CompletionResponse: self.question_data["question"] = question - self.model = self.question_data["model"] = model or self.model + self.question_data["model"] = model or self.model self.max_tokens = self.question_data["max_tokens"] = ( max_tokens or self.max_tokens ) - task_response = self._submit_question_sync(self.question_data) if self.return_output: return task_response diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml index ac3c1348a1984..c2a87606ab749 100644 --- a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-mymagic" readme = "README.md" -version = "0.1.6" +version = "0.1.7" [tool.poetry.dependencies] python = ">=3.8.1,<4.0"