Update injection metric to use an onnx model version
naddeoa committed Apr 5, 2024
1 parent 87bb683 commit a9c7327
Showing 10 changed files with 25 additions and 33 deletions.
4 changes: 4 additions & 0 deletions .github/actions/python-build/action.yml
@@ -4,6 +4,8 @@ inputs:
   type:
     description: python_version
     required: true
+  whylabs_api_key:
+    required: true
 
 runs:
   using: "composite"
@@ -34,6 +36,8 @@ runs:
 
     - name: Run test
       shell: bash
+      env:
+        WHYLABS_API_KEY: ${{ inputs.whylabs_api_key }}
       run: make test
 
     - name: Make dists
4 changes: 3 additions & 1 deletion .github/workflows/build.yaml
@@ -21,7 +21,7 @@ jobs:
         uses: ./.github/actions/python-build
         with:
           python_version: ${{ matrix.python_version }}
-
+          whylabs_api_key: ${{ secrets.WHYLABS_API_KEY }}
 
   enforce_cache_constraint:
     name: Make sure the cache constraint logic works
@@ -84,4 +84,6 @@ jobs:
         uses: docker/setup-buildx-action@v3
 
       - name: Run test
+        env:
+          WHYLABS_API_KEY: ${{ secrets.WHYLABS_API_KEY }}
         run: make test-cache-constraint
2 changes: 2 additions & 0 deletions Dockerfile.cache_test
@@ -21,6 +21,8 @@ RUN chmod -R a+rw $CONTAINER_CACHE_BASE
 ## Install/build pip dependencies
 ##
 USER root
+ARG WHYLABS_API_KEY
+
 RUN apt-get update && apt-get install -y curl build-essential
 USER whylabs
 # Install poetry
2 changes: 1 addition & 1 deletion Makefile
@@ -15,7 +15,7 @@ test: ## Run the tests
 	poetry run pytest tests -vvv -o log_level=INFO -o log_cli=true
 
 test-cache-constraint:
-	docker build -f ./Dockerfile.cache_test . -t langkit_cache_test
+	docker build -f ./Dockerfile.cache_test --build-arg WHYLABS_API_KEY=$(WHYLABS_API_KEY) . -t langkit_cache_test
 
 load-test:
 	poetry run pytest -vvv langkit/tests -o log_level=WARN -o log_cli=true --load
2 changes: 1 addition & 1 deletion langkit/core/metric.py
@@ -41,7 +41,7 @@ def to_numpy(self, column_name: str) -> np.ndarray[Any, Any]:
 
     def to_list(self, column_name: str) -> List[Any]:
         if column_name not in self.text:
-            raise ValueError(f"Column {column_name} not found in {self.text}")
+            raise KeyError(f"Column {column_name} not found in {self.text}")
 
         if isinstance(self.text, pd.DataFrame):
             col = cast("pd.Series[Any]", self.text[column_name])
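
Aside: a quick sketch of why KeyError is the better fit here. A missing column on a pandas DataFrame already raises KeyError when indexed directly, so to_list now fails the same way as the lookup it guards (the DataFrame below is made up for illustration):

    import pandas as pd

    df = pd.DataFrame({"prompt": ["hello", "goodbye"]})

    try:
        df["response"]  # pandas raises KeyError for a missing column
    except KeyError as e:
        print(f"missing column: {e}")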
3 changes: 2 additions & 1 deletion langkit/metrics/embeddings_types.py
@@ -15,7 +15,8 @@ def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
 
 
 class EmbeddingEncoder(Protocol):
-    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor": ...
+    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
+        ...
 
 
 class CachingEmbeddingEncoder(EmbeddingEncoder):
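
Aside: because EmbeddingEncoder is a Protocol, any class with a matching encode signature satisfies it structurally, with no inheritance required. A minimal sketch (the constant encoder and its 384-dimension output are invented for illustration, not part of the repo):

    from typing import Tuple

    import torch


    class ConstantEncoder:
        """Hypothetical encoder: one all-ones vector per input string."""

        def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
            return torch.ones(len(text), 384)  # shape (batch, embedding_dim)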
4 changes: 2 additions & 2 deletions langkit/metrics/embeddings_utils.py
@@ -3,10 +3,10 @@
 import torch
 import torch.nn.functional as F
 
-from langkit.metrics.embeddings_types import TransformerEmbeddingAdapter
+from langkit.metrics.embeddings_types import EmbeddingEncoder
 
 
-def compute_embedding_similarity(encoder: TransformerEmbeddingAdapter, _in: List[str], _out: List[str]) -> torch.Tensor:
+def compute_embedding_similarity(encoder: EmbeddingEncoder, _in: List[str], _out: List[str]) -> torch.Tensor:
     in_encoded = torch.as_tensor(encoder.encode(tuple(_in)))
     out_encoded = torch.as_tensor(encoder.encode(tuple(_out)))
     return F.cosine_similarity(in_encoded, out_encoded, dim=1)
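
Aside: widening the parameter from TransformerEmbeddingAdapter to the EmbeddingEncoder protocol means compute_embedding_similarity accepts the onnx-backed encoder and the sentence-transformers one alike. A usage sketch with the hypothetical ConstantEncoder from above:

    from langkit.metrics.embeddings_utils import compute_embedding_similarity

    similarity = compute_embedding_similarity(
        encoder=ConstantEncoder(),
        _in=["What is the capital of France?"],
        _out=["Paris is the capital of France."],
    )
    print(similarity)  # tensor([1.]) -- identical constant vectors, one score per row pair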
6 changes: 3 additions & 3 deletions langkit/metrics/injections.py
@@ -81,11 +81,11 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
         input_series: "pd.Series[str]" = cast("pd.Series[str]", text[column_name])
 
         if onnx:
-            _transformer = embedding_adapter()  # onnx
-            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()  # onnx
+            _transformer = embedding_adapter()
+            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()
         else:
             _transformer = sentence_transformer()
-            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series), show_progress_bar=False)
+            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(list(input_series), show_progress_bar=False)  # pyright: ignore[reportAssignmentType, reportUnknownMemberType]
 
         target_norms = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)
         cosine_similarities = np.dot(_embeddings, target_norms.T)
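
Aside: the scoring math in the hunk above reduces cosine similarity to a matrix product. A self-contained numpy sketch, under the assumption that _embeddings holds row-normalized reference embeddings of known injection prompts (the shapes and the max aggregation are my assumptions; neither is shown in this diff):

    import numpy as np

    rng = np.random.default_rng(0)

    # Stand-in for the precomputed injection reference embeddings (assumed shape).
    _embeddings = rng.normal(size=(8, 384))
    _embeddings /= np.linalg.norm(_embeddings, axis=1, keepdims=True)

    # Stand-in for the freshly encoded prompts being scored.
    target_embeddings = rng.normal(size=(2, 384)).astype(np.float32)
    target_norms = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)

    # Both operands have unit-length rows, so the dot product is cosine similarity.
    cosine_similarities = np.dot(_embeddings, target_norms.T)  # shape (8, 2)
    scores = cosine_similarities.max(axis=0)  # closest reference per prompt (assumed aggregation)
    print(scores)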
29 changes: 6 additions & 23 deletions langkit/onnx_encoder.py
@@ -1,14 +1,11 @@
 # pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownParameterType=false
-import time
 from enum import Enum
 from functools import lru_cache
-from os import environ
 from typing import Any, List, Tuple, cast
 
 import numpy as np
 import onnxruntime as ort  # pyright: ignore[reportMissingImports]
 import torch
-from psutil import cpu_count
 from transformers import BertTokenizerFast
 
 from langkit.asset_downloader import get_asset
@@ -17,19 +14,7 @@
 
 @lru_cache
 def _get_inference_session(onnx_file_path: str):
-    cpus = cpu_count(logical=True)
-    # environ["OMP_NUM_THREADS"] = str(cpus)
-    # environ["OMP_WAIT_POLICY"] = "ACTIVE"
-    sess_opts: ort.SessionOptions = ort.SessionOptions()
-    # sess_opts.enable_cpu_mem_arena = True
-    # sess_opts.inter_op_num_threads = cpus
-    # sess_opts.intra_op_num_threads = 1
-    # sess_opts.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-
-    # sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-    # sess_opts.enable_mem_pattern = True
-
-    return ort.InferenceSession(onnx_file_path, providers=["CPUExecutionProvider"], sess_options=sess_opts)  # pyright: ignore[reportUnknownArgumentType]
+    return ort.InferenceSession(onnx_file_path, providers=["CPUExecutionProvider"])  # pyright: ignore[reportUnknownArgumentType]
 
 
 class TransformerModel(Enum):
@@ -40,25 +25,23 @@ def get_model_path(self):
         return f"{get_asset(name, tag)}/{name}.onnx"
 
 
-# _times: List[float] = []
-
-
 class OnnxSentenceTransformer(EmbeddingEncoder):
     def __init__(self, model: TransformerModel):
         self._tokenizer: BertTokenizerFast = cast(BertTokenizerFast, BertTokenizerFast.from_pretrained("bert-base-uncased"))
         self._model = model
         self._session = _get_inference_session(model.get_model_path())
 
     def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
-        model_inputs = self._tokenizer.batch_encode_plus(list(text), return_tensors="pt", padding=True, truncation=True)
+        # Pre-truncate the inputs to the model length for better performance
+        max_length_in_chars = self._tokenizer.model_max_length * 5  # approx limit
+        truncated_text = tuple(content[:max_length_in_chars] for content in text)
+        model_inputs = self._tokenizer.batch_encode_plus(list(truncated_text), return_tensors="pt", padding=True, truncation=True)
 
         input_tensor: torch.Tensor = cast(torch.Tensor, model_inputs["input_ids"])
        inputs_onnx = {"input_ids": input_tensor.cpu().numpy()}
         attention_mask: torch.Tensor = cast(torch.Tensor, model_inputs["attention_mask"])
         inputs_onnx["attention_mask"] = attention_mask.cpu().detach().numpy().astype(np.float32)
-        start_time = time.perf_counter()
         onnx_output: List['np.ndarray["Any", "Any"]'] = cast(List['np.ndarray["Any", "Any"]'], self._session.run(None, inputs_onnx))
-        # _times.append(time.perf_counter() - start_time)
-        # print(f"Average time: {sum(_times) / len(_times)}")
         embedding = OnnxSentenceTransformer.mean_pooling(onnx_output=onnx_output, attention_mask=attention_mask)
         return embedding[0]
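
Aside: the character-level pre-truncation added in encode is a cheap upper bound. bert-base-uncased caps input at model_max_length (512) tokens, and a wordpiece token rarely spans more than about 5 characters, so slicing to 512 * 5 characters trims pathological inputs before tokenization while truncation=True still enforces the exact token limit. Separately, mean_pooling is called above but its body is outside this diff; a sketch of the standard masked mean pooling used with sentence-transformers-style models (the function name and argument shapes here are my assumptions):

    import torch


    def mean_pooling_sketch(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # token_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, dim=1)  # zero out padding, sum over tokens
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)  # avoid division by zero
        return summed / counts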
2 changes: 1 addition & 1 deletion langkit/transformer.py
@@ -1,7 +1,7 @@
 from functools import lru_cache
-import torch
 from typing import Tuple
 
+import torch
 from sentence_transformers import SentenceTransformer
 
 from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder
