Skip to content

Commit

Permalink
Update validate_models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
NIXBLACK11 authored Nov 14, 2023
1 parent 92345be commit 87a08e9
Showing 1 changed file with 1 addition and 15 deletions.
16 changes: 1 addition & 15 deletions laser_encoders/validate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,10 @@ def test_validate_language_models_and_tokenize_laser2(lang):
print(f"{lang} model validated successfully")


class MockLaserModelDownloader:
class MockLaserModelDownloader(LaserModelDownloader):
def __init__(self, model_dir):
self.model_dir = model_dir

def get_language_code(self, language_list: dict, lang: str) -> str:
try:
lang_3_4 = language_list[lang]
if isinstance(lang_3_4, tuple):
options = ", ".join(f"'{opt}'" for opt in lang_3_4)
raise ValueError(
f"Language '{lang_3_4}' has multiple options: {options}. Please specify using --lang."
)
return lang_3_4
except KeyError:
raise ValueError(
f"language name: {lang} not found in language list. Specify a supported language name"
)

def download_laser3(self, lang):
lang = self.get_language_code(LASER3_LANGUAGE, lang)
file_path = os.path.join(self.model_dir, f"laser3-{lang}.v1.pt")
Expand Down

0 comments on commit 87a08e9

Please sign in to comment.