Skip to content

Commit

Permalink
local models (#233)
Browse files Browse the repository at this point in the history
* local models
* Update langkit/examples/Local_Models.ipynb

---------

Co-authored-by: felipe207 <[email protected]>
Co-authored-by: richard-rogers <[email protected]>
Co-authored-by: Jamie Broomall <[email protected]>
  • Loading branch information
4 people authored Mar 27, 2024
1 parent 21f726b commit a818af7
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 38 deletions.
11 changes: 11 additions & 0 deletions langkit/docs/modules.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ profile = why.log({"prompt":"What is the primary function of the mitochondria in
"response":"The Eiffel Tower is a renowned landmark in Paris, France"}, schema=text_schema).profile()
```

#### Configuration

- [Local model path configuration](https://github.com/whylabs/langkit/blob/main/langkit/examples/Local_Models.ipynb)
- [Custom Encoder configuration](https://github.com/whylabs/langkit/blob/main/langkit/examples/Custom_Encoder.ipynb)

### `response.relevance_to_prompt`

The `response.relevance_to_prompt` computed column will contain a similarity score between the prompt and response. The higher the score, the more relevant the response is to the prompt.
Expand Down Expand Up @@ -415,6 +420,8 @@ from langkit import themes
themes.init(theme_file_path="path/to/themes.json")
```

Users can also use local models with `themes`. See the [Local Model](https://github.com/whylabs/langkit/blob/main/langkit/examples/Local_Models.ipynb) example for more information.

### `jailbreaks`

This group gathers a set of known jailbreak examples.
Expand Down Expand Up @@ -482,3 +489,7 @@ results = extract({"prompt": "I hate you."})
```

For more information, see the [Toxicity Model Configuration](https://github.com/whylabs/langkit/blob/main/langkit/examples/Toxicity_Model_Configuration.ipynb) example.

Users can also pass a local model to `toxicity`. Currently, only `martin-ha/toxic-comment-model` is supported with local use. See the example in:

- [Local model path configuration](https://github.com/whylabs/langkit/blob/main/langkit/examples/Local_Models.ipynb)
209 changes: 209 additions & 0 deletions langkit/examples/Local_Models.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
">### 🚩 *Create a free WhyLabs account to complete this example!*<br> \n",
">*Did you know you can store, visualize, and monitor whylogs profiles with the [WhyLabs Observability Platform](https://whylabs.ai/whylabs-free-sign-up?utm_source=github&utm_medium=referral&utm_campaign=Local_Models)? Sign up for a [free WhyLabs account](https://whylabs.ai/whylogs-free-signup?utm_source=github&utm_medium=referral&utm_campaign=Local_Models) to leverage the power of whylogs and WhyLabs together!*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Langkit with Local Models\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/whylabs/LanguageToolkit/blob/main/langkit/examples/Local_Models.ipynb)\n",
"\n",
"Some of the Langkit modules download models from the internet. This is not always possible, for example, when running in an environment without internet access. In this example, we will show how you can use Langkit with models stored locally."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by installing LangKit:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install langkit[all]==0.0.31 -q"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're also assuming the existence of local models in specific folders, such as when downloading the models with the script below.\n",
"\n",
"Make sure you have git-lfs installed. If not, you can install it by running:\n",
"\n",
"`sudo apt-get install git-lfs`"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'local-toxicity-model'...\n",
"remote: Enumerating objects: 40, done.\u001b[K\n",
"remote: Total 40 (delta 0), reused 0 (delta 0), pack-reused 40\u001b[K\n",
"Unpacking objects: 100% (40/40), 301.27 KiB | 414.00 KiB/s, done.\n",
"Cloning into 'local-sentence-transformers'...\n",
"remote: Enumerating objects: 49, done.\u001b[K\n",
"remote: Counting objects: 100% (3/3), done.\u001b[K\n",
"remote: Compressing objects: 100% (3/3), done.\u001b[K\n",
"remote: Total 49 (delta 0), reused 0 (delta 0), pack-reused 46\u001b[K\n",
"Unpacking objects: 100% (49/49), 316.57 KiB | 311.00 KiB/s, done.\n",
"Filtering content: 100% (3/3), 260.15 MiB | 16.47 MiB/s, done.\n"
]
}
],
"source": [
"!git clone https://huggingface.co/martin-ha/toxic-comment-model local-toxicity-model\n",
"!git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 local-sentence-transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `martin-ha/toxic-comment-model` is the model currently used in `toxicity`, and `sentence-transformers/all-MiniLM-L6-v2` is used to generate embeddings in both `themes` and `input_output_modules`. We can pass the local paths when initializing the modules:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from langkit import themes\n",
"from langkit import toxicity\n",
"from langkit import input_output\n",
"\n",
"from langkit import LangKitConfig\n",
"\n",
"local_config = LangKitConfig(toxicity_model_path=\"local-toxicity-model\",\n",
" transformer_name=\"local-sentence-transformers\")\n",
"\n",
"toxicity.init(config=local_config)\n",
"themes.init(config=local_config)\n",
"input_output.init(config=local_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If, for example, we want a local version for the `llm_metrics` module, we also need to import `textstat`, `regexes`, and `sentiment`. `regexes` and `textstat` are lightweight models and don't require external artifacts, so we can use them in a network restricted environment. `sentiment`, however, downloads artifacts from the internet, so let's replace it with `vader_sentiment`, which will yield the same results as `sentiment`, with the benefit of not requiring downloading artifacts at runtime."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from langkit import regexes\n",
"from langkit import vader_sentiment\n",
"from langkit import textstat"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we should have an equivalent version of `llm_metrics` that doesn't require internet access. Let's check the results for a toy example:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'prompt': 'I like you. I love you',\n",
" 'response': 'thanks!',\n",
" 'prompt.jailbreak_similarity': 0.2522321939468384,\n",
" 'response.refusal_similarity': 0.1535428911447525,\n",
" 'prompt.toxicity': 0.006519913673400879,\n",
" 'response.toxicity': 0.0011597275733947754,\n",
" 'response.relevance_to_prompt': 0.23008441925048828,\n",
" 'prompt.has_patterns': None,\n",
" 'response.has_patterns': None,\n",
" 'prompt.vader_sentiment': 0.7717,\n",
" 'response.vader_sentiment': 0.4926,\n",
" 'prompt.flesch_reading_ease': 119.19,\n",
" 'response.flesch_reading_ease': 121.22,\n",
" 'prompt.automated_readability_index': -6.7,\n",
" 'response.automated_readability_index': 12.0,\n",
" 'prompt.aggregate_reading_level': 1.0,\n",
" 'response.aggregate_reading_level': 0.0,\n",
" 'prompt.syllable_count': 6,\n",
" 'response.syllable_count': 1,\n",
" 'prompt.lexicon_count': 6,\n",
" 'response.lexicon_count': 1,\n",
" 'prompt.sentence_count': 2,\n",
" 'response.sentence_count': 1,\n",
" 'prompt.character_count': 17,\n",
" 'response.character_count': 7,\n",
" 'prompt.letter_count': 16,\n",
" 'response.letter_count': 6,\n",
" 'prompt.polysyllable_count': 0,\n",
" 'response.polysyllable_count': 0,\n",
" 'prompt.monosyllable_count': 6,\n",
" 'response.monosyllable_count': 1,\n",
" 'prompt.difficult_words': 0,\n",
" 'response.difficult_words': 0}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from whylogs.experimental.core.udf_schema import udf_schema\n",
"from langkit import extract\n",
"\n",
"text_schema = udf_schema()\n",
"result = extract({\"prompt\":\"I like you. I love you\",\"response\":\"thanks!\"},schema=text_schema)\n",
"\n",
"result"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langkit-rNdo63Yk-py3.8",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
6 changes: 3 additions & 3 deletions langkit/injections.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def init(
)

try:
np_embeddings = np.stack(harm_embeddings["sentence_embedding"].values).astype(
np.float32
)
array_list = [np.array(x) for x in harm_embeddings["sentence_embedding"].values]
np_embeddings = np.stack(array_list).astype(np.float32)

_embeddings_norm = np_embeddings / np.linalg.norm(
np_embeddings, axis=1, keepdims=True
)
Expand Down
13 changes: 9 additions & 4 deletions langkit/themes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,30 @@ def group_similarity(text: str, group):
raise ValueError("Must initialize a transformer before calling encode!")

text_embedding = _transformer_model.encode(text)
_cache_embeddings_map(group)
for embedding in _embeddings_map.get(group, []):
similarity = get_embeddings_similarity(text_embedding, embedding)
similarities.append(similarity)
return max(similarities) if similarities else None


def _map_embeddings():
global _embeddings_map
for group in _theme_groups:
def _cache_embeddings_map(group):
if group not in _embeddings_map:
_embeddings_map[group] = [
_transformer_model.encode(s) for s in _theme_groups.get(group, [])
]


def _clear_embeddings_map():
global _embeddings_map
_embeddings_map = {}


_registered = set()


def _register_theme_udfs():
global _registered
_map_embeddings()

for group in _theme_groups:
for column in [_prompt, _response]:
Expand Down Expand Up @@ -111,6 +115,7 @@ def init(
_theme_groups = load_themes(config.theme_file_path)
else:
_theme_groups = load_themes(theme_file_path)
_clear_embeddings_map()
_register_theme_udfs()


Expand Down
55 changes: 35 additions & 20 deletions langkit/toxicity.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from copy import deepcopy
from typing import Optional

from functools import lru_cache
from whylogs.experimental.core.udf_schema import register_dataset_udf
from langkit import LangKitConfig, lang_config, prompt_column, response_column

import os
import torch
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
TextClassificationPipeline,
)

_USE_CUDA = torch.cuda.is_available() and not bool(
os.environ.get("LANGKIT_NO_CUDA", False)
Expand All @@ -14,6 +18,27 @@

_prompt = prompt_column
_response = response_column


@lru_cache(maxsize=None)
def _get_tokenizer(model_path: str):
return AutoTokenizer.from_pretrained(model_path)


@lru_cache(maxsize=None)
def _get_model(model_path: str):
return AutoModelForSequenceClassification.from_pretrained(model_path)


@lru_cache(maxsize=None)
def _get_pipeline(model_path: str):
return TextClassificationPipeline(
model=_get_model(model_path),
tokenizer=_get_tokenizer(model_path),
device=_device,
)


_toxicity_model: Optional["ToxicityModel"] = None


Expand All @@ -34,21 +59,13 @@ def predict(self, text: str):

class ToxicCommentModel(ToxicityModel):
def __init__(self, model_path: str):
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
TextClassificationPipeline,
)

self.toxicity_tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.toxicity_pipeline = TextClassificationPipeline(
model=model, tokenizer=self.toxicity_tokenizer, device=_device
)
self.model_path = model_path

def predict(self, text: str) -> float:
result = self.toxicity_pipeline(
text, truncation=True, max_length=self.toxicity_tokenizer.model_max_length
toxicity_pipeline = _get_pipeline(self.model_path)
toxicity_tokenizer = _get_tokenizer(self.model_path)
result = toxicity_pipeline(
text, truncation=True, max_length=toxicity_tokenizer.model_max_length
)
return (
result[0]["score"]
Expand Down Expand Up @@ -76,16 +93,14 @@ def init(model_path: Optional[str] = None, config: Optional[LangKitConfig] = Non
config = config or deepcopy(lang_config)
model_path = model_path or config.toxicity_model_path
global _toxicity_model
if model_path == "martin-ha/toxic-comment-model":
_toxicity_model = ToxicCommentModel(model_path)
elif model_path == "detoxify/unbiased":
if model_path == "detoxify/unbiased":
_toxicity_model = DetoxifyModel("unbiased")
elif model_path == "detoxify/original":
_toxicity_model = DetoxifyModel("original")
elif model_path == "detoxify/multilingual":
_toxicity_model = DetoxifyModel("multilingual")
else:
raise ValueError(f"Unknown toxicity model: {model_path}")
else: # assume it's martin-ha/toxic-comment-model, remote or from local path
_toxicity_model = ToxicCommentModel(model_path)


init()
Loading

0 comments on commit a818af7

Please sign in to comment.