Skip to content

Commit

Permalink
add better typing
Browse files: browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Nov 16, 2024
1 parent 3efc07c commit 80d65be
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 26 deletions.
2 changes: 1 addition & 1 deletion libs/embed_package/embed/_infer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from concurrent.futures import Future
from typing import Collection, Literal, Union

from infinity_emb import EngineArgs, SyncEngineArray # type: ignore
from infinity_emb import EngineArgs, SyncEngineArray
from infinity_emb.infinity_server import AutoPadding

__all__ = ["BatchedInference"]
Expand Down
6 changes: 3 additions & 3 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput):
),
Annotated[str, INPUT_STRING],
]
modality: Literal[Modality.text] = Modality.text # type: ignore
modality: Literal[Modality.text] = Modality.text


class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):
Expand All @@ -82,11 +82,11 @@ class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):


class OpenAIEmbeddingInput_Audio(_OpenAIEmbeddingInput_URI):
modality: Literal[Modality.audio] = Modality.audio # type: ignore
modality: Literal[Modality.audio] = Modality.audio


class OpenAIEmbeddingInput_Image(_OpenAIEmbeddingInput_URI):
modality: Literal[Modality.image] = Modality.image # type: ignore
modality: Literal[Modality.image] = Modality.image


def get_modality(obj: dict) -> str:
Expand Down
4 changes: 2 additions & 2 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,9 @@ def spawn(self) -> tuple[set[ModelCapabilites], dict]:

extras = {}
if hasattr(self._model, "sampling_rate"):
extras["sampling_rate"] = self._model.sampling_rate # type: ignore
extras["sampling_rate"] = self._model.sampling_rate

return self._model.capabilities, extras # type: ignore
return self._model.capabilities, extras

@property
def model(self):
Expand Down
6 changes: 3 additions & 3 deletions libs/infinity_emb/infinity_emb/inference/select_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)

if TYPE_CHECKING:
from infinity_emb.transformer.abstract import CallableReturningBaseTypeHint, BaseTypeHint
from infinity_emb.transformer.abstract import BaseTypeHint # , CallableReturningBaseTypeHint
from infinity_emb.args import (
EngineArgs,
)
Expand Down Expand Up @@ -91,7 +91,7 @@ def _get_engine_replica(unloaded_engine, engine_args, device_map) -> "BaseTypeHi

def select_model(
engine_args: "EngineArgs",
) -> list["CallableReturningBaseTypeHint"]:
) -> list[partial["BaseTypeHint"]]:
"""based on engine args, fully instantiates the Engine."""
logger.info(
f"model=`{engine_args.model_name_or_path}` selected, "
Expand All @@ -108,4 +108,4 @@ def select_model(
)
assert len(engine_replicas) > 0, "No engine replicas were loaded"

return engine_replicas # type: ignore
return engine_replicas
23 changes: 12 additions & 11 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,20 @@ def create_server(

@asynccontextmanager
async def lifespan(app: FastAPI):
instrumentator.expose(app) # type: ignore
instrumentator.expose(app)
logger.info(
f"Creating {len(engine_args_list)} engines: engines={[e.served_model_name for e in engine_args_list]}"
)
telemetry_log_info()
app.engine_array = AsyncEngineArray.from_args(engine_args_list) # type: ignore
engine_array = AsyncEngineArray.from_args(engine_args_list)
th = threading.Thread(
target=send_telemetry_start,
args=(engine_args_list, [{} for e in app.engine_array]), # type: ignore
args=(engine_args_list, [{} for e in engine_array]),
)
th.daemon = True
th.start()
# start in a threadpool
await app.engine_array.astart() # type: ignore
await engine_array.astart()

logger.info(
docs.startup_message(
Expand All @@ -120,8 +120,9 @@ async def kill_later(seconds: int):
logger.info(f"Preloaded configuration successfully. {engine_args_list} " " -> exit .")
asyncio.create_task(kill_later(3))

app.engine_array = engine_array # type: ignore
yield
await app.engine_array.astop() # type: ignore
await engine_array.astop()
# shutdown!

app = FastAPI(
Expand Down Expand Up @@ -691,7 +692,7 @@ def v1(
device: Device = MANAGER.device[0], # type: ignore
lengths_via_tokenize: bool = MANAGER.lengths_via_tokenize[0],
dtype: Dtype = MANAGER.dtype[0], # type: ignore
embedding_dtype: EmbeddingDtype = EmbeddingDtype.default_value(), # type: ignore
embedding_dtype: EmbeddingDtype = EmbeddingDtype.default_value(),
pooling_method: PoolingMethod = MANAGER.pooling_method[0], # type: ignore
compile: bool = MANAGER.compile[0],
bettertransformer: bool = MANAGER.bettertransformer[0],
Expand All @@ -701,7 +702,7 @@ def v1(
url_prefix: str = MANAGER.url_prefix,
host: str = MANAGER.host,
port: int = MANAGER.port,
log_level: UVICORN_LOG_LEVELS = MANAGER.log_level, # type: ignore
log_level: UVICORN_LOG_LEVELS = MANAGER.log_level,
):
"""Infinity API ♾️ cli v1 - deprecated, consider use cli v2 via `infinity_emb v2`."""
if api_key:
Expand All @@ -719,9 +720,9 @@ def v1(
time.sleep(1)
v2(
model_id=[model_name_or_path],
served_model_name=[served_model_name], # type: ignore
served_model_name=[served_model_name],
batch_size=[batch_size],
revision=[revision], # type: ignore
revision=[revision],
trust_remote_code=[trust_remote_code],
engine=[engine],
dtype=[dtype],
Expand All @@ -732,7 +733,7 @@ def v1(
lengths_via_tokenize=[lengths_via_tokenize],
compile=[compile],
bettertransformer=[bettertransformer],
embedding_dtype=[EmbeddingDtype.float32], # set to float32
embedding_dtype=[EmbeddingDtype.float32],
# unique kwargs
preload_only=preload_only,
url_prefix=url_prefix,
Expand Down Expand Up @@ -846,7 +847,7 @@ def v2(
),
log_level: UVICORN_LOG_LEVELS = typer.Option(
**_construct("log_level"), help="console log level."
), # type: ignore
),
permissive_cors: bool = typer.Option(
**_construct("permissive_cors"), help="whether to allow permissive cors."
),
Expand Down
2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/transformer/audio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import aiohttp

if CHECK_SOUNDFILE.is_available:
import soundfile as sf # type: ignore
import soundfile as sf # type: ignore[import-untyped]


async def resolve_audio(
Expand Down
2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/transformer/embedder/ct2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Module: # type: ignore[no-redef]


if CHECK_CTRANSLATE2.is_available:
import ctranslate2 # type: ignore
import ctranslate2 # type: ignore[import-untyped]


class CT2SentenceTransformer(SentenceTransformerPatched):
Expand Down
4 changes: 2 additions & 2 deletions libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def encode_core(self, input_dict: dict[str, np.ndarray]) -> dict:
}

@quant_embedding_decorator()
def encode_post(self, embedding: dict) -> EmbeddingReturnType:
embedding = self.pooling( # type: ignore
def encode_post(self, embedding: dict[str, "torch.Tensor"]) -> EmbeddingReturnType:
embedding = self.pooling(
embedding["token_embeddings"].numpy(), embedding["attention_mask"].numpy()
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any, Union

import numpy as np
import requests # type: ignore
import requests # type: ignore[import-untyped]

from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH
from infinity_emb.env import MANAGER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
if CHECK_TORCH.is_available:
import torch
if CHECK_TRANSFORMERS.is_available:
from transformers import AutoConfig, AutoModel, AutoProcessor # type: ignore
from transformers import AutoConfig, AutoModel, AutoProcessor # type: ignore[import-untyped]
if CHECK_PIL.is_available:
from PIL import Image

Expand Down

0 comments on commit 80d65be

Please sign in to comment.