This repository has been archived by the owner on Sep 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sean Friedowitz
committed
Jan 11, 2024
1 parent
9db31da
commit 25a5785
Showing
15 changed files
with
85 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,9 @@ | ||
from .model_name_or_path import ModelNameOrCheckpointPath | ||
from .quantization_config import QuantizationConfig | ||
from .utils import is_valid_huggingface_model_name | ||
from .trainer_config import TrainerConfig | ||
|
||
__all__ = ["QuantizationConfig", "is_valid_huggingface_model_name"] | ||
__all__ = [ | ||
"ModelNameOrCheckpointPath", | ||
"QuantizationConfig", | ||
"TrainerConfig", | ||
] |
49 changes: 49 additions & 0 deletions
49
src/flamingo/integrations/huggingface/model_name_or_path.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from dataclasses import InitVar | ||
from pathlib import Path | ||
|
||
from huggingface_hub.utils import HFValidationError, validate_repo_id | ||
from pydantic.dataclasses import dataclass | ||
|
||
|
||
def is_valid_huggingface_model_name(s: str): | ||
""" | ||
Simple test to check if an HF model is valid using HuggingFace's tools. | ||
Sadly, theirs throws an exception and has no return. | ||
Args: | ||
s: string to test. | ||
""" | ||
try: | ||
validate_repo_id(s) | ||
return True | ||
except HFValidationError: | ||
return False | ||
|
||
|
||
@dataclass | ||
class ModelNameOrCheckpointPath: | ||
""" | ||
This class is explicitly used to validate if a string is | ||
a valid HuggingFace model or can be used as a checkpoint. | ||
Checkpoint will be automatically assigned if it's a valid checkpoint; | ||
it will be None if it's not valid. | ||
""" | ||
|
||
# explictly needed for matching | ||
__match_args__ = ("name", "checkpoint") | ||
|
||
name: str | ||
checkpoint: InitVar[str | None] = None | ||
|
||
def __post_init__(self, checkpoint): | ||
if isinstance(self.name, Path): | ||
self.name = str(self.name) | ||
|
||
if Path(self.name).is_absolute(): | ||
self.checkpoint = self.name | ||
else: | ||
self.checkpoint = None | ||
|
||
if self.checkpoint is None and not is_valid_huggingface_model_name(self.name): | ||
raise (ValueError(f"{self.name} is not a valid checkpoint path or HF model name")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from flamingo.types import BaseFlamingoConfig | ||
|
||
|
||
class TrainerConfig(BaseFlamingoConfig): | ||
pass |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from flamingo.jobs.configs import SimpleJobConfig | ||
|
||
|
||
def main(config: SimpleJobConfig): | ||
def run(config: SimpleJobConfig): | ||
print(f"The magic number is {config.magic_number}") |