Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
split apart methods and update examples
Browse files Browse the repository at this point in the history
Sean Friedowitz committed Jan 23, 2024
1 parent 31d8838 commit 883edcd
Showing 5 changed files with 95 additions and 97 deletions.
37 changes: 22 additions & 15 deletions examples/dataset_preprocessing.ipynb
Original file line number Diff line number Diff line change
@@ -46,16 +46,13 @@
"def preprocess_dataset(examples):\n",
" texts = []\n",
" for x in examples[\"prompt\"]:\n",
" texts.append(x[::-1]) # Dummy reverse the prompt\n",
" texts.append(x[::-1]) # Dummy reverse the prompt\n",
" examples[\"text\"] = texts\n",
" return examples\n",
"\n",
"\n",
"# Map some preprocessing function over the base dataset (e.g., for prompt formatting)\n",
"dataset = dataset.map(\n",
" preprocess_dataset, \n",
" batched=True,\n",
" remove_columns=dataset.column_names\n",
")\n",
"dataset = dataset.map(preprocess_dataset, batched=True, remove_columns=dataset.column_names)\n",
"\n",
"dataset"
]
@@ -93,14 +90,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "8b09e47d-3ced-4eef-a89f-048754edc758",
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"from flamingo.integrations.wandb import ArtifactType, ArtifactURIScheme, log_artifact_from_path\n",
"from flamingo.integrations.wandb import (\n",
" ArtifactType,\n",
" ArtifactURIScheme,\n",
" log_directory_contents,\n",
" log_directory_reference,\n",
")\n",
"from flamingo.jobs.utils import FlamingoJobType"
]
},
@@ -115,15 +117,20 @@
" name=\"dataset-preprocessing-example\",\n",
" project=\"sfriedowitz-dev\",\n",
" entity=\"mozilla-ai\",\n",
" job_type=FlamingoJobType.PREPROCESSING\n",
" job_type=FlamingoJobType.PREPROCESSING,\n",
"):\n",
" # Specify that this path references local files by adding `uri_scheme = ArtifactURIScheme.FILE`\n",
" # This will upload a reference to the dataset, rather than uploading the actual files\n",
" log_artifact_from_path(\n",
" name=\"example-dataset-artfact\",\n",
" path=dataset_save_path,\n",
" # Log a reference to the directory contents\n",
" log_directory_reference(\n",
" dir_path=dataset_save_path,\n",
" artifact_name=\"example-dataset-artfact-reference\",\n",
" artifact_type=ArtifactType.DATASET,\n",
" scheme=ArtifactURIScheme.FILE,\n",
" )\n",
" # Log and upload the directory contents\n",
" log_directory_contents(\n",
" dir_path=dataset_save_path,\n",
" artifact_name=\"example-dataset-artfact-upload\",\n",
" artifact_type=ArtifactType.DATASET,\n",
" uri_scheme=ArtifactURIScheme.FILE\n",
" )"
]
}
64 changes: 13 additions & 51 deletions examples/dev_workflow.ipynb
Original file line number Diff line number Diff line change
@@ -10,14 +10,15 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "8c0f15ed-77dc-44ce-adb6-d1b59368f03c",
"metadata": {},
"outputs": [],
"source": [
"# Required imports\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"from ray.job_submission import JobSubmissionClient\n",
"\n",
"# flamingo should be installed in your development environment\n",
@@ -26,19 +27,19 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "969884e5-d815-42d9-9d4e-3b8f890657e2",
"metadata": {},
"outputs": [],
"source": [
"# Create a submission client bound to a Ray cluster\n",
"# Note: You will likely have to update the cluster address shown below\n",
"client = JobSubmissionClient(f\"http://10.146.174.91:8265\")"
"client = JobSubmissionClient(\"http://10.146.174.91:8265\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "3258bb97-d3c6-4fee-aa0c-962c1411eaa7",
"metadata": {},
"outputs": [],
@@ -49,21 +50,10 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"id": "1db3b9aa-99a4-49d9-8773-7b91ccf89c85",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SimpleJobConfig(magic_number=42)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Load and inspect the config file\n",
"# Not mandatory for job submission, but helpful when debugging\n",
@@ -78,7 +68,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"id": "b81b36be-35ce-4398-a6d4-ac1f719f5c95",
"metadata": {},
"outputs": [],
@@ -88,53 +78,25 @@
"# pip contains an export of the dependencies for the flamingo package (see CONTRIBUTING.md for how to generate)\n",
"runtime_env = {\n",
" \"working_dir\": str(CONFIG_DIR),\n",
" \"env_vars\": {\"WANDB_API_KEY\": os.environ[\"WANDB_API_KEY\"]}, # If running a job that uses W&B\n",
" \"env_vars\": {\"WANDB_API_KEY\": os.environ[\"WANDB_API_KEY\"]}, # If running a job that uses W&B\n",
" \"py_modules\": [str(flamingo_module)],\n",
" \"pip\": \"/path/to/flamingo/requirements.txt\"\n",
" \"pip\": \"/path/to/flamingo/requirements.txt\",\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"id": "4bd300f9-b863-4413-bd3a-430601656816",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-01-20 15:32:25,620\tINFO dashboard_sdk.py:385 -- Package gcs://_ray_pkg_ba0036a72fdb32af.zip already exists, skipping upload.\n",
"2024-01-20 15:32:25,814\tINFO dashboard_sdk.py:385 -- Package gcs://_ray_pkg_8f96eb40a239b233.zip already exists, skipping upload.\n"
]
},
{
"data": {
"text/plain": [
"'raysubmit_tWfixDMGHavrhHPF'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Submit the job to the Ray cluster\n",
"# Note: flamingo is invoked by 'python -m flamingo' since the CLI is not installed in the environment\n",
"client.submit_job(\n",
" entrypoint=f\"python -m flamingo run simple --config {CONFIG_FILE}\",\n",
" runtime_env=runtime_env\n",
" entrypoint=f\"python -m flamingo run simple --config {CONFIG_FILE}\", runtime_env=runtime_env\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c82892d-bcdf-42e6-b95e-2393e01ab7d6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ testpaths = ["tests"]

[tool.ruff]
target-version = "py310"

exclude = [
".bzr",
".direnv",
@@ -72,6 +73,9 @@ exclude = [
"node_modules",
"venv",
]

extend-include = ["*.ipynb"]

line-length = 100

[tool.ruff.extend-per-file-ignores]
@@ -82,6 +86,10 @@ line-length = 100


[tool.ruff.lint]

# Only format Jupyter notebooks, but don't lint them
exclude = ["*.ipynb"]

select = [
"E", # pycodestyle
"F", # pyflakes
70 changes: 46 additions & 24 deletions src/flamingo/integrations/wandb/artifact_utils.py
Original file line number Diff line number Diff line change
@@ -66,42 +66,64 @@ def get_artifact_filesystem_path(
return artifact.download(root=download_root)


def log_artifact_from_path(
name: str,
path: str | Path,
def log_directory_contents(
dir_path: str | Path,
artifact_name: str,
artifact_type: ArtifactType,
*,
uri_scheme: ArtifactURIScheme | None = None,
max_objects: int | None = None,
entry_name: str | None = None,
) -> wandb.Artifact:
"""Log an artifact containing the contents of a directory to the currently active run.
"""Log the contents of a directory as an artifact of the active run.
A run should already be initialized before calling this function.
If not, an exception will be raised.
Example usage:
```
with wandb_init_from_config(run_config):
log_artifact_from_path(...)
```
Args:
dir_path (str | Path): Path to the artifact directory.
artifact_name (str): Name of the artifact.
artifact_type (ArtifactType): Type of the artifact to create.
entry_name (str, optional): Name given to the directory contents within the artifact.
Returns:
The `wandb.Artifact` that was produced
"""
artifact = wandb.Artifact(name=artifact_name, type=artifact_type)
artifact.add_dir(str(dir_path), name=entry_name)
return wandb.log_artifact(artifact)


def log_directory_reference(
dir_path: str | Path,
artifact_name: str,
artifact_type: ArtifactType,
*,
scheme: ArtifactURIScheme = ArtifactURIScheme.FILE,
entry_name: str | None = None,
max_objects: int | None = None,
) -> wandb.Artifact:
"""Log a reference to a directory's contents as an artifact of the active run.
A run should already be initialized before calling this function.
If not, an exception will be raised.
Args:
name (str): Name of the artifact
path (str | Path): Path to the artifact directory
artifact_type (ArtifactType): Type of the artifact to create
uri_scheme (ArtifactURIScheme, optional): URI scheme to prepend to the artifact path.
When provided, the artifact is logged as a reference to this path.
dir_path (str | Path): Path to the artifact directory.
artifact_name (str): Name of the artifact.
artifact_type (ArtifactType): Type of the artifact to create.
scheme (ArtifactURIScheme): URI scheme to prepend to the artifact path.
Defaults to `ArtifactURIScheme.FILE` for filesystem references.
entry_name (str, optional): Name given to the directory reference within the artifact.
max_objects (int, optional): Max number of objects allowed in the artifact.
Only used when creating reference artifacts.
Returns:
The `wandb.Artifact` that was logged
The `wandb.Artifact` that was produced
"""
artifact = wandb.Artifact(name=name, type=artifact_type)
if uri_scheme is not None:
artifact.add_reference(f"{uri_scheme}://{path}", max_objects=max_objects)
else:
artifact.add_dir(str(path))
# Log artifact to the currently active run
artifact = wandb.Artifact(name=artifact_name, type=artifact_type)
artifact.add_reference(
uri=f"{scheme}://{dir_path}",
name=entry_name,
max_objects=max_objects,
)
return wandb.log_artifact(artifact)
13 changes: 6 additions & 7 deletions src/flamingo/jobs/finetuning/entrypoint.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
ArtifactType,
ArtifactURIScheme,
default_artifact_name,
log_artifact_from_path,
log_directory_reference,
wandb_init_from_config,
)
from flamingo.jobs.finetuning import FinetuningJobConfig
@@ -149,11 +149,10 @@ def run_finetuning(config: FinetuningJobConfig):
if config.tracking and result.checkpoint:
# Must resume from the just-completed training run
with wandb_init_from_config(config.tracking, resume="must") as run:
print("Building artifact for model checkpoint...")
artifact_name = default_artifact_name(run.name, ArtifactType.MODEL)
log_artifact_from_path(
name=artifact_name,
path=f"{result.checkpoint.path}/checkpoint",
print("Logging artifact for model checkpoint...")
log_directory_reference(
dir_path=f"{result.checkpoint.path}/checkpoint",
artifact_name=default_artifact_name(run.name, ArtifactType.MODEL),
artifact_type=ArtifactType.MODEL,
uri_scheme=ArtifactURIScheme.FILE,
scheme=ArtifactURIScheme.FILE,
)

0 comments on commit 883edcd

Please sign in to comment.