Merge pull request #282 from whylabs/asset-downloader
Asset downloader
naddeoa authored Apr 5, 2024
2 parents 88ac4ec + a9c7327 commit 78fcbdb
Showing 13 changed files with 314 additions and 24 deletions.
4 changes: 4 additions & 0 deletions .github/actions/python-build/action.yml
@@ -4,6 +4,8 @@ inputs:
type:
description: python_version
required: true
whylabs_api_key:
required: true

runs:
using: "composite"
@@ -34,6 +36,8 @@ runs:

- name: Run test
shell: bash
env:
WHYLABS_API_KEY: ${{ inputs.whylabs_api_key }}
run: make test

- name: Make dists
4 changes: 3 additions & 1 deletion .github/workflows/build.yaml
@@ -21,7 +21,7 @@ jobs:
uses: ./.github/actions/python-build
with:
python_version: ${{ matrix.python_version }}

whylabs_api_key: ${{ secrets.WHYLABS_API_KEY }}

enforce_cache_constraint:
name: Make sure the cache constraint logic works
@@ -84,4 +84,6 @@ jobs:
uses: docker/setup-buildx-action@v3

- name: Run test
env:
WHYLABS_API_KEY: ${{ secrets.WHYLABS_API_KEY }}
run: make test-cache-constraint
2 changes: 2 additions & 0 deletions Dockerfile.cache_test
@@ -21,6 +21,8 @@ RUN chmod -R a+rw $CONTAINER_CACHE_BASE
## Install/build pip dependencies
##
USER root
ARG WHYLABS_API_KEY

RUN apt-get update && apt-get install -y curl build-essential
USER whylabs
# Install poetry
2 changes: 1 addition & 1 deletion Makefile
@@ -15,7 +15,7 @@ test: ## Run the tests
poetry run pytest tests -vvv -o log_level=INFO -o log_cli=true

test-cache-constraint:
docker build -f ./Dockerfile.cache_test . -t langkit_cache_test
docker build -f ./Dockerfile.cache_test --build-arg WHYLABS_API_KEY=$(WHYLABS_API_KEY) . -t langkit_cache_test

load-test:
poetry run pytest -vvv langkit/tests -o log_level=WARN -o log_cli=true --load
96 changes: 96 additions & 0 deletions langkit/asset_downloader.py
@@ -0,0 +1,96 @@
# pyright: reportUnknownMemberType=none, reportUnknownVariableType=none
import logging
import os
import zipfile
from dataclasses import dataclass
from typing import cast

import requests
import whylabs_client
from tenacity import retry, stop_after_attempt, wait_exponential_jitter
from whylabs_client.api.assets_api import AssetsApi
from whylabs_client.model.get_asset_response import GetAssetResponse

from langkit.config import LANGKIT_CACHE

logger = logging.getLogger(__name__)

configuration = whylabs_client.Configuration(host="https://api.whylabsapp.com")
configuration.api_key["ApiKeyAuth"] = os.environ["WHYLABS_API_KEY"]
configuration.discard_unknown_keys = True

client = whylabs_client.ApiClient(configuration)
assets_api = AssetsApi(client)


@dataclass
class AssetPath:
asset_id: str
tag: str
zip_path: str
extract_path: str


def _get_asset_path(asset_id: str, tag: str = "0") -> AssetPath:
return AssetPath(
asset_id=asset_id,
tag=tag,
zip_path=f"{LANGKIT_CACHE}/assets/{asset_id}/{tag}/{asset_id}.zip",
extract_path=f"{LANGKIT_CACHE}/assets/{asset_id}/{tag}/{asset_id}/",
)


def _is_extracted(asset_id: str, tag: str = "0") -> bool:
asset_path = _get_asset_path(asset_id, tag)
if not os.path.exists(asset_path.zip_path):
return False

with zipfile.ZipFile(asset_path.zip_path, "r") as zip_ref:
zip_names = set(zip_ref.namelist())
extract_names = set(os.listdir(asset_path.extract_path))
return zip_names.issubset(extract_names)


def _extract_asset(asset_id: str, tag: str = "0"):
asset_path = _get_asset_path(asset_id, tag)
with zipfile.ZipFile(asset_path.zip_path, "r") as zip_ref:
zip_ref.extractall(asset_path.extract_path)


def _is_zip_file(file_path: str) -> bool:
try:
with zipfile.ZipFile(file_path, "r"):
return True
except zipfile.BadZipFile:
return False


@retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5))
def _download_asset(asset_id: str, tag: str = "0"):
asset_path = _get_asset_path(asset_id, tag)
response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id))
url = cast(str, response.download_url)
os.makedirs(os.path.dirname(asset_path.zip_path), exist_ok=True)
r = requests.get(url, stream=True)
with open(asset_path.zip_path, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
f.write(chunk)

if not _is_zip_file(asset_path.zip_path):
os.remove(asset_path.zip_path)
raise ValueError(f"Downloaded file {asset_path.zip_path} is not a zip file")


def get_asset(asset_id: str, tag: str = "0"):
asset_path = _get_asset_path(asset_id, tag)
if _is_extracted(asset_id, tag):
logger.info(f"Asset {asset_id} with tag {tag} already downloaded and extracted")
return asset_path.extract_path

if not os.path.exists(asset_path.zip_path):
logger.info(f"Downloading asset {asset_id} with tag {tag} to {asset_path.zip_path}")
_download_asset(asset_id, tag)

logger.info(f"Extracting asset {asset_id} with tag {tag}")
_extract_asset(asset_id, tag)
return asset_path.extract_path
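A quick usage sketch of the new downloader (illustrative only, not part of the diff). WHYLABS_API_KEY has to be present in the environment before the module is imported, since the client configuration reads it at import time; the key value below is a placeholder.

```python
# Illustrative sketch, not part of this commit.
import os

os.environ.setdefault("WHYLABS_API_KEY", "<your-api-key>")  # placeholder value

from langkit.asset_downloader import get_asset

# Downloads the zip (with retries), validates it, extracts it under
# LANGKIT_CACHE/assets/<asset_id>/<tag>/, and returns the extraction directory.
extract_dir = get_asset("all-MiniLM-L6-v2", tag="0")
print(extract_dir)
```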
2 changes: 1 addition & 1 deletion langkit/core/metric.py
@@ -41,7 +41,7 @@ def to_numpy(self, column_name: str) -> np.ndarray[Any, Any]:

def to_list(self, column_name: str) -> List[Any]:
if column_name not in self.text:
raise ValueError(f"Column {column_name} not found in {self.text}")
raise KeyError(f"Column {column_name} not found in {self.text}")

if isinstance(self.text, pd.DataFrame):
col = cast("pd.Series[Any]", self.text[column_name])
9 changes: 9 additions & 0 deletions langkit/metrics/embeddings_types.py
@@ -17,3 +17,12 @@ def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
class EmbeddingEncoder(Protocol):
def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
...


class CachingEmbeddingEncoder(EmbeddingEncoder):
def __init__(self, transformer: EmbeddingEncoder):
self._transformer = transformer

@lru_cache(maxsize=6, typed=True)
def encode(self, text: Tuple[str, ...]) -> "torch.Tensor": # pyright: ignore[reportIncompatibleMethodOverride]
return self._transformer.encode(text) # type: ignore[no-any-return]
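A minimal sketch of what the lru_cache on encode buys, using a hypothetical CountingEncoder stand-in: repeated calls with the same tuple of strings reuse the cached tensor instead of re-running the wrapped encoder.

```python
# Illustrative sketch; CountingEncoder is a hypothetical stub for demonstration.
from typing import Tuple

import torch

from langkit.metrics.embeddings_types import CachingEmbeddingEncoder


class CountingEncoder:
    """Counts how many times encode() actually runs."""

    def __init__(self) -> None:
        self.calls = 0

    def encode(self, text: Tuple[str, ...]) -> torch.Tensor:
        self.calls += 1
        return torch.zeros(len(text), 3)


inner = CountingEncoder()
cached = CachingEmbeddingEncoder(inner)
cached.encode(("a", "b"))
cached.encode(("a", "b"))  # identical tuple: served from the lru_cache
print(inner.calls)  # 1
```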
4 changes: 2 additions & 2 deletions langkit/metrics/embeddings_utils.py
@@ -3,10 +3,10 @@
import torch
import torch.nn.functional as F

from langkit.metrics.embeddings_types import TransformerEmbeddingAdapter
from langkit.metrics.embeddings_types import EmbeddingEncoder


def compute_embedding_similarity(encoder: TransformerEmbeddingAdapter, _in: List[str], _out: List[str]) -> torch.Tensor:
def compute_embedding_similarity(encoder: EmbeddingEncoder, _in: List[str], _out: List[str]) -> torch.Tensor:
in_encoded = torch.as_tensor(encoder.encode(tuple(_in)))
out_encoded = torch.as_tensor(encoder.encode(tuple(_out)))
return F.cosine_similarity(in_encoded, out_encoded, dim=1)
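Since the parameter is now typed against the EmbeddingEncoder protocol, any object with a matching encode() can be passed; a minimal sketch with a hypothetical stub encoder:

```python
# Illustrative sketch; StubEncoder is hypothetical and only for demonstration.
from typing import Tuple

import torch

from langkit.metrics.embeddings_utils import compute_embedding_similarity


class StubEncoder:
    """Returns deterministic 2-d embeddings keyed on string length."""

    def encode(self, text: Tuple[str, ...]) -> torch.Tensor:
        return torch.tensor([[float(len(t)), 1.0] for t in text])


similarity = compute_embedding_similarity(StubEncoder(), ["hello"], ["hello world"])
print(similarity)  # 1-d tensor of per-row cosine similarities
```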
20 changes: 14 additions & 6 deletions langkit/metrics/injections.py
@@ -1,7 +1,7 @@
import os
from functools import lru_cache, partial
from logging import getLogger
from typing import Any, Sequence
from typing import Any, Sequence, cast

import numpy as np
import numpy.typing as npt
@@ -10,7 +10,7 @@
from langkit.config import LANGKIT_CACHE
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.metrics.util import retry
from langkit.transformer import sentence_transformer
from langkit.transformer import embedding_adapter, sentence_transformer

logger = getLogger(__name__)

@@ -66,19 +66,27 @@ def _get_embeddings(version: str) -> "np.ndarray[Any, Any]":
return embeddings_norm


def injections_metric(column_name: str, version: str = "v2") -> Metric:
def injections_metric(column_name: str, version: str = "v2", onnx: bool = True) -> Metric:
def cache_assets():
_get_embeddings(version)

def init():
sentence_transformer()
embedding_adapter()

def udf(text: pd.DataFrame) -> SingleMetricResult:
if column_name not in text.columns:
raise ValueError(f"Injections: Column {column_name} not found in input dataframe")
_embeddings = _get_embeddings(version)
_transformer = sentence_transformer()
target_embeddings: npt.NDArray[np.float32] = _transformer.encode(text[column_name]) # type: ignore[reportUnknownMemberType]

input_series: "pd.Series[str]" = cast("pd.Series[str]", text[column_name])

if onnx:
_transformer = embedding_adapter()
target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()
else:
_transformer = sentence_transformer()
target_embeddings: npt.NDArray[np.float32] = _transformer.encode(list(input_series), show_progress_bar=False) # pyright: ignore[reportAssignmentType, reportUnknownMemberType]

target_norms = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)
cosine_similarities = np.dot(_embeddings, target_norms.T)
max_similarities = np.max(cosine_similarities, axis=0) # type: ignore[reportUnknownMemberType]
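The scoring in the udf above boils down to normalized dot products against the known-injection embeddings; a minimal numpy sketch with random stand-ins for the real embeddings (384 dimensions is an assumption, matching all-MiniLM-L6-v2):

```python
# Illustrative sketch of the scoring step only; random arrays stand in for the
# real injection and prompt embeddings.
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the pre-normalized known-injection embeddings from _get_embeddings().
injections = rng.normal(size=(100, 384))
injections /= np.linalg.norm(injections, axis=1, keepdims=True)

# Stand-in for the embeddings of the incoming prompt column.
targets = rng.normal(size=(2, 384))
target_norms = targets / np.linalg.norm(targets, axis=1, keepdims=True)

cosine_similarities = np.dot(injections, target_norms.T)  # shape (100, 2)
max_similarities = np.max(cosine_similarities, axis=0)    # one score per input row
print(max_similarities)
```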
56 changes: 56 additions & 0 deletions langkit/onnx_encoder.py
@@ -0,0 +1,56 @@
# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownParameterType=false
from enum import Enum
from functools import lru_cache
from typing import Any, List, Tuple, cast

import numpy as np
import onnxruntime as ort # pyright: ignore[reportMissingImports]
import torch
from transformers import BertTokenizerFast

from langkit.asset_downloader import get_asset
from langkit.metrics.embeddings_types import EmbeddingEncoder


@lru_cache
def _get_inference_session(onnx_file_path: str):
return ort.InferenceSession(onnx_file_path, providers=["CPUExecutionProvider"]) # pyright: ignore[reportUnknownArgumentType]


class TransformerModel(Enum):
AllMiniLM = ("all-MiniLM-L6-v2", "0")

def get_model_path(self):
name, tag = self.value
return f"{get_asset(name, tag)}/{name}.onnx"


class OnnxSentenceTransformer(EmbeddingEncoder):
def __init__(self, model: TransformerModel):
self._tokenizer: BertTokenizerFast = cast(BertTokenizerFast, BertTokenizerFast.from_pretrained("bert-base-uncased"))
self._model = model
self._session = _get_inference_session(model.get_model_path())

def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
# Pre-truncate the inputs to the model length for better performance
max_length_in_chars = self._tokenizer.model_max_length * 5 # approx limit
truncated_text = tuple(content[:max_length_in_chars] for content in text)
model_inputs = self._tokenizer.batch_encode_plus(list(truncated_text), return_tensors="pt", padding=True, truncation=True)

input_tensor: torch.Tensor = cast(torch.Tensor, model_inputs["input_ids"])
inputs_onnx = {"input_ids": input_tensor.cpu().numpy()}
attention_mask: torch.Tensor = cast(torch.Tensor, model_inputs["attention_mask"])
inputs_onnx["attention_mask"] = attention_mask.cpu().detach().numpy().astype(np.float32)
onnx_output: List['np.ndarray["Any", "Any"]'] = cast(List['np.ndarray["Any", "Any"]'], self._session.run(None, inputs_onnx))
embedding = OnnxSentenceTransformer.mean_pooling(onnx_output=onnx_output, attention_mask=attention_mask)
return embedding[0]

@staticmethod
def mean_pooling(
onnx_output: List['np.ndarray["Any", "Any"]'], attention_mask: torch.Tensor
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
token_embeddings = torch.from_numpy(onnx_output[0])
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask, input_mask_expanded, sum_mask
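A brief usage sketch of the new ONNX-backed encoder (illustrative; the first call downloads the model asset via the asset downloader, so WHYLABS_API_KEY must be set in the environment):

```python
# Illustrative sketch, not part of this commit.
from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel

encoder = OnnxSentenceTransformer(TransformerModel.AllMiniLM)
embeddings = encoder.encode(("a short prompt", "another prompt"))
print(embeddings.shape)  # (2, 384) for all-MiniLM-L6-v2
```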
7 changes: 4 additions & 3 deletions langkit/transformer.py
@@ -4,7 +4,8 @@
import torch
from sentence_transformers import SentenceTransformer

from langkit.metrics.embeddings_types import TransformerEmbeddingAdapter
from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder
from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel


@lru_cache
@@ -24,5 +25,5 @@ def sentence_transformer(


@lru_cache
def embedding_adapter() -> TransformerEmbeddingAdapter:
return TransformerEmbeddingAdapter(sentence_transformer())
def embedding_adapter() -> EmbeddingEncoder:
return CachingEmbeddingEncoder(OnnxSentenceTransformer(TransformerModel.AllMiniLM))
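Because embedding_adapter() is itself lru_cached, every caller shares one ONNX-backed encoder and its embedding cache; a quick sketch of the expected behavior:

```python
# Illustrative sketch, not part of this commit.
from langkit.transformer import embedding_adapter

encoder = embedding_adapter()
assert encoder is embedding_adapter()  # same cached instance on every call

vectors = encoder.encode(("cached prompt",))
vectors_again = encoder.encode(("cached prompt",))  # served from the embedding lru_cache
```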
(Diffs for the remaining 2 of the 13 changed files were not rendered on this page.)