-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #282 from whylabs/asset-downloader
Asset downloader
- Loading branch information
Showing
13 changed files
with
314 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# pyright: reportUnknownMemberType=none, reportUnknownVariableType=none | ||
import logging | ||
import os | ||
import zipfile | ||
from dataclasses import dataclass | ||
from typing import cast | ||
|
||
import requests | ||
import whylabs_client | ||
from tenacity import retry, stop_after_attempt, wait_exponential_jitter | ||
from whylabs_client.api.assets_api import AssetsApi | ||
from whylabs_client.model.get_asset_response import GetAssetResponse | ||
|
||
from langkit.config import LANGKIT_CACHE | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
configuration = whylabs_client.Configuration(host="https://api.whylabsapp.com") | ||
configuration.api_key["ApiKeyAuth"] = os.environ["WHYLABS_API_KEY"] | ||
configuration.discard_unknown_keys = True | ||
|
||
client = whylabs_client.ApiClient(configuration) | ||
assets_api = AssetsApi(client) | ||
|
||
|
||
@dataclass | ||
class AssetPath: | ||
asset_id: str | ||
tag: str | ||
zip_path: str | ||
extract_path: str | ||
|
||
|
||
def _get_asset_path(asset_id: str, tag: str = "0") -> AssetPath: | ||
return AssetPath( | ||
asset_id=asset_id, | ||
tag=tag, | ||
zip_path=f"{LANGKIT_CACHE}/assets/{asset_id}/{tag}/{asset_id}.zip", | ||
extract_path=f"{LANGKIT_CACHE}/assets/{asset_id}/{tag}/{asset_id}/", | ||
) | ||
|
||
|
||
def _is_extracted(asset_id: str, tag: str = "0") -> bool: | ||
asset_path = _get_asset_path(asset_id, tag) | ||
if not os.path.exists(asset_path.zip_path): | ||
return False | ||
|
||
with zipfile.ZipFile(asset_path.zip_path, "r") as zip_ref: | ||
zip_names = set(zip_ref.namelist()) | ||
extract_names = set(os.listdir(asset_path.extract_path)) | ||
return zip_names.issubset(extract_names) | ||
|
||
|
||
def _extract_asset(asset_id: str, tag: str = "0"): | ||
asset_path = _get_asset_path(asset_id, tag) | ||
with zipfile.ZipFile(asset_path.zip_path, "r") as zip_ref: | ||
zip_ref.extractall(asset_path.extract_path) | ||
|
||
|
||
def _is_zip_file(file_path: str) -> bool: | ||
try: | ||
with zipfile.ZipFile(file_path, "r"): | ||
return True | ||
except zipfile.BadZipFile: | ||
return False | ||
|
||
|
||
@retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5)) | ||
def _download_asset(asset_id: str, tag: str = "0"): | ||
asset_path = _get_asset_path(asset_id, tag) | ||
response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id)) | ||
url = cast(str, response.download_url) | ||
os.makedirs(os.path.dirname(asset_path.zip_path), exist_ok=True) | ||
r = requests.get(url, stream=True) | ||
with open(asset_path.zip_path, "wb") as f: | ||
for chunk in r.iter_content(chunk_size=1024): | ||
f.write(chunk) | ||
|
||
if not _is_zip_file(asset_path.zip_path): | ||
os.remove(asset_path.zip_path) | ||
raise ValueError(f"Downloaded file {asset_path.zip_path} is not a zip file") | ||
|
||
|
||
def get_asset(asset_id: str, tag: str = "0"): | ||
asset_path = _get_asset_path(asset_id, tag) | ||
if _is_extracted(asset_id, tag): | ||
logger.info(f"Asset {asset_id} with tag {tag} already downloaded and extracted") | ||
return asset_path.extract_path | ||
|
||
if not os.path.exists(asset_path.zip_path): | ||
logger.info(f"Downloading asset {asset_id} with tag {tag} to {asset_path.zip_path}") | ||
_download_asset(asset_id, tag) | ||
|
||
logger.info(f"Extracting asset {asset_id} with tag {tag}") | ||
_extract_asset(asset_id, tag) | ||
return asset_path.extract_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownParameterType=false | ||
from enum import Enum | ||
from functools import lru_cache | ||
from typing import Any, List, Tuple, cast | ||
|
||
import numpy as np | ||
import onnxruntime as ort # pyright: ignore[reportMissingImports] | ||
import torch | ||
from transformers import BertTokenizerFast | ||
|
||
from langkit.asset_downloader import get_asset | ||
from langkit.metrics.embeddings_types import EmbeddingEncoder | ||
|
||
|
||
@lru_cache | ||
def _get_inference_session(onnx_file_path: str): | ||
return ort.InferenceSession(onnx_file_path, providers=["CPUExecutionProvider"]) # pyright: ignore[reportUnknownArgumentType] | ||
|
||
|
||
class TransformerModel(Enum): | ||
AllMiniLM = ("all-MiniLM-L6-v2", "0") | ||
|
||
def get_model_path(self): | ||
name, tag = self.value | ||
return f"{get_asset(name, tag)}/{name}.onnx" | ||
|
||
|
||
class OnnxSentenceTransformer(EmbeddingEncoder): | ||
def __init__(self, model: TransformerModel): | ||
self._tokenizer: BertTokenizerFast = cast(BertTokenizerFast, BertTokenizerFast.from_pretrained("bert-base-uncased")) | ||
self._model = model | ||
self._session = _get_inference_session(model.get_model_path()) | ||
|
||
def encode(self, text: Tuple[str, ...]) -> "torch.Tensor": | ||
# Pre-truncate the inputs to the model length for better performance | ||
max_length_in_chars = self._tokenizer.model_max_length * 5 # approx limit | ||
truncated_text = tuple(content[:max_length_in_chars] for content in text) | ||
model_inputs = self._tokenizer.batch_encode_plus(list(truncated_text), return_tensors="pt", padding=True, truncation=True) | ||
|
||
input_tensor: torch.Tensor = cast(torch.Tensor, model_inputs["input_ids"]) | ||
inputs_onnx = {"input_ids": input_tensor.cpu().numpy()} | ||
attention_mask: torch.Tensor = cast(torch.Tensor, model_inputs["attention_mask"]) | ||
inputs_onnx["attention_mask"] = attention_mask.cpu().detach().numpy().astype(np.float32) | ||
onnx_output: List['np.ndarray["Any", "Any"]'] = cast(List['np.ndarray["Any", "Any"]'], self._session.run(None, inputs_onnx)) | ||
embedding = OnnxSentenceTransformer.mean_pooling(onnx_output=onnx_output, attention_mask=attention_mask) | ||
return embedding[0] | ||
|
||
@staticmethod | ||
def mean_pooling( | ||
onnx_output: List['np.ndarray["Any", "Any"]'], attention_mask: torch.Tensor | ||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: | ||
token_embeddings = torch.from_numpy(onnx_output[0]) | ||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()) | ||
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) | ||
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) | ||
return sum_embeddings / sum_mask, input_mask_expanded, sum_mask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.