This repository has been archived by the owner on Sep 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sean Friedowitz
committed
Jan 16, 2024
1 parent
0306f3b
commit a03d272
Showing
13 changed files
with
105 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,19 @@ | ||
from peft import LoraConfig | ||
from pydantic import validator | ||
|
||
from flamingo.integrations.huggingface import QuantizationConfig | ||
from flamingo.integrations.huggingface.utils import repo_name_validator | ||
from flamingo.integrations.wandb import WandbArtifactLink | ||
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype | ||
|
||
|
||
class AutoModelConfig(BaseFlamingoConfig): | ||
"""Settings passed to a HuggingFace AutoModel instantiation.""" | ||
"""Settings passed to a HuggingFace AutoModel instantiation. | ||
The model path can either be a string corresponding to a HuggingFace repo ID, | ||
or an artifact link to a reference artifact on W&B. | ||
""" | ||
|
||
path: str | WandbArtifactLink | ||
trust_remote_code: bool = False | ||
torch_dtype: SerializableTorchDtype = None | ||
quantization: QuantizationConfig | None = None | ||
lora: LoraConfig | None = None # TODO: Create own dataclass here | ||
|
||
_path_validator = validator("path", allow_reuse=True, pre=True)(repo_name_validator) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,42 @@ | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import wandb | ||
from wandb.apis.public import Run | ||
|
||
from flamingo.integrations.wandb import WandbEnvironment | ||
from flamingo.integrations.wandb.wandb_artifact_link import WandbArtifactLink | ||
from flamingo.integrations.wandb import WandbArtifactLink, WandbEnvironment | ||
|
||
|
||
def get_wandb_artifact(link: WandbArtifactLink): | ||
def get_wandb_run(env: WandbEnvironment) -> Run: | ||
"""Retrieve a run from the W&B API.""" | ||
api = wandb.Api() | ||
return api.artifact(link.artifact_path()) | ||
return api.run(env.wandb_path) | ||
|
||
|
||
def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]: | ||
"""Get the summary dictionary attached to a W&B run.""" | ||
run = _resolve_wandb_run(env) | ||
run = get_wandb_run(env) | ||
return dict(run.summary) | ||
|
||
|
||
def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None: | ||
"""Update a run's summary with the provided metrics.""" | ||
run = _resolve_wandb_run(env) | ||
run = get_wandb_run(env) | ||
run.summary.update(metrics) | ||
run.update() | ||
|
||
|
||
def _resolve_wandb_run(env: WandbEnvironment) -> Run: | ||
"""Resolve a WandB run object from the provided environment settings. | ||
An exception is raised if a Run cannot be found, | ||
or if multiple runs exist in scope with the same name. | ||
""" | ||
def get_wandb_artifact(link: WandbArtifactLink) -> wandb.Artifact: | ||
"""Retrieve an artifact from the W&B API.""" | ||
api = wandb.Api() | ||
base_path = "/".join(x for x in (env.entity, env.project) if x) | ||
if env.run_id is not None: | ||
full_path = f"{base_path}/{env.run_id}" | ||
return api.run(full_path) | ||
else: | ||
match [run for run in api.runs(base_path) if run.name == env.name]: | ||
case []: | ||
raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}") | ||
case [Run(), _]: | ||
raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}") | ||
case [Run()] as mr: | ||
# we have a single one, hurray | ||
return mr[0] | ||
return api.artifact(link.wandb_path) | ||
|
||
|
||
def get_artifact_filesystem_path(link: WandbArtifactLink) -> str: | ||
# TODO: What if there are multiple folder paths in the artifact manifest? | ||
artifact = get_wandb_artifact(link) | ||
for entry in artifact.manifest.entries.values(): | ||
if entry.ref.startswith("file://"): | ||
entry_path = Path(entry.ref.replace("file://", "")) | ||
return str(entry_path.parent.absolute()) | ||
raise ValueError("Artifact does not contain reference to filesystem files.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,9 @@ | ||
from .finetuning_config import FinetuningJobConfig | ||
from .lm_harness_config import LMHarnessJobConfig, ModelNameOrCheckpointPath | ||
from .lm_harness_config import LMHarnessJobConfig | ||
from .simple_config import SimpleJobConfig | ||
|
||
__all__ = [ | ||
"SimpleJobConfig", | ||
"FinetuningJobConfig", | ||
"LMHarnessJobConfig", | ||
"ModelNameOrCheckpointPath", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.