Skip to content

Commit

Permalink
Move s3 processing outside autolabel
Browse files Browse the repository at this point in the history
  • Loading branch information
DhruvaBansal00 committed Dec 4, 2024
1 parent 80f607e commit c7fb84b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 83 deletions.
46 changes: 1 addition & 45 deletions src/autolabel/task_chain/task_chain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional

import boto3
import pandas as pd
Expand All @@ -19,7 +18,6 @@
from autolabel.few_shot.base_label_selector import BaseLabelSelector
from autolabel.labeler import LabelingAgent
from autolabel.transforms import TransformFactory
from autolabel.utils import generate_presigned_url, is_s3_uri

logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.WARNING)
Expand Down Expand Up @@ -159,10 +157,6 @@ async def run(self, dataset_df: pd.DataFrame):
for task in subtasks:
autolabel_config = AutolabelConfig(task)
dataset = AutolabelDataset(dataset_df, autolabel_config)
dataset, original_inputs = self.safe_convert_uri_to_presigned_url(
dataset,
autolabel_config,
)
if autolabel_config.transforms():
agent = LabelingAgent(
config=autolabel_config,
Expand Down Expand Up @@ -199,11 +193,6 @@ async def run(self, dataset_df: pd.DataFrame):
dataset,
skip_eval=True,
)
dataset = self.reset_presigned_url_to_uri(
dataset,
original_inputs,
autolabel_config,
)
dataset = self.rename_output_columns(dataset, autolabel_config)
dataset_df = dataset.df
return dataset
Expand Down Expand Up @@ -233,36 +222,3 @@ def rename_output_columns(
].apply(lambda x: x.get(attribute) if x and type(x) is dict else None)

return dataset

def safe_convert_uri_to_presigned_url(
self,
dataset: AutolabelDataset,
autolabel_config: AutolabelConfig,
) -> Tuple[AutolabelDataset, List[Dict]]:
original_inputs = copy.deepcopy(dataset.inputs)
for col in autolabel_config.input_columns():
for i in range(len(dataset.inputs)):
if col in dataset.inputs[i]:
dataset.inputs[i][col] = (
generate_presigned_url(
self.s3_client,
dataset.inputs[i][col],
)
if is_s3_uri(dataset.inputs[i][col])
else dataset.inputs[i][col]
)
dataset.df.loc[i, col] = dataset.inputs[i][col]
return dataset, original_inputs

def reset_presigned_url_to_uri(
self,
dataset: AutolabelDataset,
original_inputs: List[Dict],
autolabel_config: AutolabelConfig,
) -> AutolabelDataset:
for col in autolabel_config.input_columns():
for i in range(len(dataset.inputs)):
if col in dataset.inputs[i] and col in original_inputs[i]:
dataset.inputs[i][col] = original_inputs[i][col]
dataset.df.loc[i, col] = dataset.inputs[i][col]
return dataset
42 changes: 4 additions & 38 deletions src/autolabel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import string
from string import Formatter
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
from urllib.parse import urlparse

import regex
import wget
Expand Down Expand Up @@ -291,7 +290,8 @@ def track_with_stats(
with live:
progress_task = progress.add_task(description=description, total=total)
stats_task = stats_progress.add_task(
"Stats", stats=", ".join(f"{k}={v}" for k, v in stats.items()),
"Stats",
stats=", ".join(f"{k}={v}" for k, v in stats.items()),
)
for value in sequence:
yield value
Expand All @@ -300,7 +300,8 @@ def track_with_stats(
advance=min(advance, total - progress.tasks[progress_task].completed),
)
stats_progress.update(
stats_task, stats=", ".join(f"{k}={v}" for k, v in stats.items()),
stats_task,
stats=", ".join(f"{k}={v}" for k, v in stats.items()),
)
live.refresh()

Expand Down Expand Up @@ -439,38 +440,3 @@ def safe_serialize_to_string(data: Dict) -> Dict:
except Exception:
ret[k] = ""
return ret


def is_s3_uri(uri_string: str) -> bool:
return uri_string is not None and (
uri_string.startswith("s3://") or uri_string.startswith("s3a://")
)


def extract_bucket_key_from_s3_url(s3_path: str):
# Refer: https://stackoverflow.com/a/48245084
if not is_s3_uri(s3_path):
logger.warning("URI is not actually an S3 URI: {}", s3_path)
return None

path_object = urlparse(s3_path)
bucket = path_object.netloc
key = path_object.path
return {"Bucket": bucket, "Key": key.lstrip("/")}


def generate_s3_uri_from_bucket_key(bucket: str, key: str) -> str:
return f"s3://{bucket}/{key}"


def generate_presigned_url(client, s3_uri, expiration=86400):
s3_params = extract_bucket_key_from_s3_url(s3_uri)

if not s3_params:
return s3_uri

return client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": s3_params["Bucket"], "Key": s3_params["Key"]},
ExpiresIn=expiration,
)

0 comments on commit c7fb84b

Please sign in to comment.