Skip to content

Commit

Permalink
Merge pull request facebookresearch#257 from NIXBLACK11/Language_mode…
Browse files Browse the repository at this point in the history
…l_validation

Adding Language Validation Test
  • Loading branch information
heffernankevin authored Nov 14, 2023
2 parents 3c5f5ed + 87a08e9 commit b0131d9
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 5 deletions.
13 changes: 10 additions & 3 deletions laser_encoders/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def download(self, filename: str):
def get_language_code(self, language_list: dict, lang: str) -> str:
try:
lang_3_4 = language_list[lang]
if isinstance(lang_3_4, tuple):
if isinstance(lang_3_4, list):
options = ", ".join(f"'{opt}'" for opt in lang_3_4)
raise ValueError(
f"Language '{lang_3_4}' has multiple options: {options}. Please specify using --lang."
f"Language '{lang}' has multiple options: {options}. Please specify using the 'lang' argument."
)
return lang_3_4
except KeyError:
Expand All @@ -88,7 +88,14 @@ def download_laser2(self):
self.download("laser2.cvocab")

def download_laser3(self, lang: str, spm: bool = False):
lang = self.get_language_code(LASER3_LANGUAGE, lang)
result = self.get_language_code(LASER3_LANGUAGE, lang)

if isinstance(result, list):
raise ValueError(
f"There are script-specific models available for {lang}. Please choose one from the following: {result}"
)

lang = result
self.download(f"laser3-{lang}.v1.pt")
if spm:
if lang in SPM_LANGUAGE:
Expand Down
4 changes: 3 additions & 1 deletion laser_encoders/laser_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = N
f"Unsupported laser model: {laser}. Choose either laser2 or laser3."
)
else:
if lang in LASER3_LANGUAGE or lang in LASER2_LANGUAGE:
if lang in LASER3_LANGUAGE:
lang = downloader.get_language_code(LASER3_LANGUAGE, lang)
if lang in SPM_LANGUAGE:
filename = f"laser3-{lang}.v1.spm"
else:
filename = "laser2.spm"
elif lang in LASER2_LANGUAGE:
filename = "laser2.spm"
else:
raise ValueError(
f"Unsupported language name: {lang}. Please specify a supported language name."
Expand Down
2 changes: 1 addition & 1 deletion laser_encoders/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ def initialize_encoder(
f"Unsupported laser model: {laser}. Choose either laser2 or laser3."
)
else:
lang = downloader.get_language_code(LASER3_LANGUAGE, lang)
if lang in LASER3_LANGUAGE:
lang = downloader.get_language_code(LASER3_LANGUAGE, lang)
downloader.download_laser3(lang=lang, spm=spm)
file_path = f"laser3-{lang}.v1"
elif lang in LASER2_LANGUAGE:
Expand Down
57 changes: 57 additions & 0 deletions laser_encoders/test_models_initialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
import tempfile

import pytest

from laser_encoders.download_models import LaserModelDownloader
from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE
from laser_encoders.laser_tokenizer import initialize_tokenizer
from laser_encoders.models import initialize_encoder


def test_validate_achnese_models_and_tokenize_laser3(lang="acehnese"):
with tempfile.TemporaryDirectory() as tmp_dir:
print(f"Created temporary directory for {lang}", tmp_dir)

downloader = LaserModelDownloader(model_dir=tmp_dir)
downloader.download_laser3(lang)
encoder = initialize_encoder(lang, model_dir=tmp_dir)
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir)

# Test tokenization with a sample sentence
tokenized = tokenizer.tokenize("This is a sample sentence.")

print(f"{lang} model validated successfully")


def test_validate_english_models_and_tokenize_laser2(lang="english"):
with tempfile.TemporaryDirectory() as tmp_dir:
print(f"Created temporary directory for {lang}", tmp_dir)

downloader = LaserModelDownloader(model_dir=tmp_dir)
downloader.download_laser2()

encoder = initialize_encoder(lang, model_dir=tmp_dir)
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir)

# Test tokenization with a sample sentence
tokenized = tokenizer.tokenize("This is a sample sentence.")

print(f"{lang} model validated successfully")


def test_validate_kashmiri_models_and_tokenize_laser3(lang="kas"):
with tempfile.TemporaryDirectory() as tmp_dir:
print(f"Created temporary directory for {lang}", tmp_dir)

downloader = LaserModelDownloader(model_dir=tmp_dir)
with pytest.raises(ValueError):
downloader.download_laser3(lang)

encoder = initialize_encoder(lang, model_dir=tmp_dir)
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir)

# Test tokenization with a sample sentence
tokenized = tokenizer.tokenize("This is a sample sentence.")

print(f"{lang} model validated successfully")
108 changes: 108 additions & 0 deletions laser_encoders/validate_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
import tempfile

import pytest

from laser_encoders.download_models import LaserModelDownloader
from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE
from laser_encoders.laser_tokenizer import initialize_tokenizer
from laser_encoders.models import initialize_encoder


@pytest.mark.slow
@pytest.mark.parametrize("lang", LASER3_LANGUAGE)
def test_validate_language_models_and_tokenize_laser3(lang):
with tempfile.TemporaryDirectory() as tmp_dir:
print(f"Created temporary directory for {lang}", tmp_dir)

downloader = LaserModelDownloader(model_dir=tmp_dir)
if lang in ["kashmiri", "kas", "central kanuri", "knc"]:
with pytest.raises(ValueError) as excinfo:
downloader.download_laser3(lang)
assert "ValueError" in str(excinfo.value)
print(f"{lang} language model raised a ValueError as expected.")
else:
downloader.download_laser3(lang)
encoder = initialize_encoder(lang, model_dir=tmp_dir)
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir)

# Test tokenization with a sample sentence
tokenized = tokenizer.tokenize("This is a sample sentence.")

print(f"{lang} model validated successfully")


@pytest.mark.slow
@pytest.mark.parametrize("lang", LASER2_LANGUAGE)
def test_validate_language_models_and_tokenize_laser2(lang):
with tempfile.TemporaryDirectory() as tmp_dir:
print(f"Created temporary directory for {lang}", tmp_dir)

downloader = LaserModelDownloader(model_dir=tmp_dir)
downloader.download_laser2()

encoder = initialize_encoder(lang, model_dir=tmp_dir)
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir)

# Test tokenization with a sample sentence
tokenized = tokenizer.tokenize("This is a sample sentence.")

print(f"{lang} model validated successfully")


class MockLaserModelDownloader(LaserModelDownloader):
def __init__(self, model_dir):
self.model_dir = model_dir

def download_laser3(self, lang):
lang = self.get_language_code(LASER3_LANGUAGE, lang)
file_path = os.path.join(self.model_dir, f"laser3-{lang}.v1.pt")
if not os.path.exists(file_path):
raise FileNotFoundError(f"Could not find {file_path}.")

def download_laser2(self):
files = ["laser2.pt", "laser2.spm", "laser2.cvocab"]
for file_name in files:
file_path = os.path.join(self.model_dir, file_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"Could not find {file_path}.")


CACHE_DIR = "/home/user/.cache/models" # Change this to the desired cache directory

# This uses the mock downloader
@pytest.mark.slow
@pytest.mark.parametrize("lang", LASER3_LANGUAGE)
def test_validate_language_models_and_tokenize_mock_laser3(lang):
downloader = MockLaserModelDownloader(model_dir=CACHE_DIR)

try:
downloader.download_laser3(lang)
except FileNotFoundError as e:
raise pytest.error(str(e))

encoder = initialize_encoder(lang, model_dir=CACHE_DIR)
tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR)

tokenized = tokenizer.tokenize("This is a sample sentence.")

print(f"{lang} model validated successfully")


# This uses the mock downloader
@pytest.mark.slow
@pytest.mark.parametrize("lang", LASER2_LANGUAGE)
def test_validate_language_models_and_tokenize_mock_laser2(lang):
downloader = MockLaserModelDownloader(model_dir=CACHE_DIR)

try:
downloader.download_laser2()
except FileNotFoundError as e:
raise pytest.error(str(e))

encoder = initialize_encoder(lang, model_dir=CACHE_DIR)
tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR)

tokenized = tokenizer.tokenize("This is a sample sentence.")

print(f"{lang} model validated successfully")

0 comments on commit b0131d9

Please sign in to comment.