Add an option to not use elastic agents for meta-reference inference (#…
ashwinb authored Oct 18, 2024
1 parent be3c5c0 commit 33afd34
Showing 2 changed files with 33 additions and 8 deletions.
7 changes: 6 additions & 1 deletion llama_stack/providers/impls/meta_reference/inference/config.py
@@ -17,13 +17,18 @@

 class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
-        default="Llama3.1-8B-Instruct",
+        default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
     )
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1

+    # when this is False, we assume that the distributed process group is setup by someone
+    # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
+    # (including our testing code) who might be using llama-stack as a library.
+    create_distributed_process_group: bool = True
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
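The new create_distributed_process_group flag is what a client embedding llama-stack as a library (or a script launched via `torchrun`) would turn off so the provider reuses an externally managed process group instead of spawning its own. A minimal sketch of such a client-side config; the import path is inferred from the relative `.config` import in the file below, and only the field names and defaults come from the diff above:

from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
)

# Sketch only: tell the provider not to create its own model-parallel process
# group; the surrounding program (torchrun, or a test harness) already set one up.
config = MetaReferenceInferenceConfig(
    model="Llama3.2-3B-Instruct",
    max_seq_len=4096,
    max_batch_size=1,
    create_distributed_process_group=False,
)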
34 changes: 27 additions & 7 deletions llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -18,6 +18,7 @@
 )

 from .config import MetaReferenceInferenceConfig
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,
@@ -36,8 +37,11 @@ def __init__(self, config: MetaReferenceInferenceConfig) -> None:

     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        self.generator = LlamaModelParallelGenerator(self.config)
-        self.generator.start()
+        if self.config.create_distributed_process_group:
+            self.generator = LlamaModelParallelGenerator(self.config)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config)

     async def register_model(self, model: ModelDef) -> None:
         raise ValueError("Dynamic model registration is not supported")
@@ -51,7 +55,8 @@ async def list_models(self) -> List[ModelDef]:
         ]

     async def shutdown(self) -> None:
-        self.generator.stop()
+        if self.config.create_distributed_process_group:
+            self.generator.stop()

     def completion(
         self,
@@ -99,8 +104,9 @@ def chat_completion(
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

-        if SEMAPHORE.locked():
-            raise RuntimeError("Only one concurrent request is supported")
+        if self.config.create_distributed_process_group:
+            if SEMAPHORE.locked():
+                raise RuntimeError("Only one concurrent request is supported")

         if request.stream:
             return self._stream_chat_completion(request)
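The pre-check above only applies when this process owns the model-parallel group; it uses asyncio.Semaphore.locked() to reject a second in-flight request without blocking. A self-contained sketch of that guard pattern, with hypothetical names that are not taken from the diff:

import asyncio

SEMAPHORE = asyncio.Semaphore(1)  # at most one request may drive the model

async def guarded(work, owns_process_group: bool):
    # Reject, rather than queue, a concurrent request when we own the process
    # group; otherwise run the work directly and let the caller manage concurrency.
    if owns_process_group:
        if SEMAPHORE.locked():
            raise RuntimeError("Only one concurrent request is supported")
        async with SEMAPHORE:
            return await work()
    return await work()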
Expand All @@ -110,7 +116,7 @@ def chat_completion(
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with SEMAPHORE:
def impl():
messages = chat_completion_request_to_messages(request)

tokens = []
@@ -154,10 +160,16 @@ async def _nonstream_chat_completion(
                 logprobs=logprobs if request.logprobs else None,
             )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        async with SEMAPHORE:
+        def impl():
             messages = chat_completion_request_to_messages(request)

             yield ChatCompletionResponseStreamChunk(
@@ -272,6 +284,14 @@ async def _stream_chat_completion(
                 )
             )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
     async def embeddings(
         self,
         model: str,
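Both _nonstream_chat_completion and _stream_chat_completion now share the same shape: the generation body is factored into a local impl() so it can run either inside the provider-wide semaphore (when this process owns the process group) or directly (when the group is managed externally). A standalone sketch of that shape, using hypothetical helper names:

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def run_once(impl, owns_process_group: bool):
    # Non-streaming path: impl() is an ordinary function returning the response.
    if owns_process_group:
        async with SEMAPHORE:
            return impl()
    return impl()

async def run_stream(impl, owns_process_group: bool):
    # Streaming path: impl() is a plain (synchronous) generator; its chunks are
    # re-yielded so the async caller sees the same stream either way.
    if owns_process_group:
        async with SEMAPHORE:
            for chunk in impl():
                yield chunk
    else:
        for chunk in impl():
            yield chunk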
