Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented custom dataset creator class #962

Open
wants to merge 5 commits into
base: nextjs
Choose a base branch
from
Open

Conversation

NMBridges
Copy link
Contributor

Added CustomDatasetCreator class

What user problem are we solving?
#913
What solution does this PR provide?
Allows backend to use user-uploaded tabular datasets for training
Testing Methodology
Uploaded testing dataset to dlp-upload-bucket/nolan/tabular/antennae-lengths.csv. Performed function call on endpoint and verified that the data was preserved during download. Also implemented automatic string encoding for non-numeric labels.

Any other considerations

@NMBridges
Copy link
Contributor Author

teehee = CustomDatasetCreator.read_s3("nolan", "antennae-length.csv", 0.2, "label", True)
train = teehee.createTrainDataset()
test = teehee.createTestDataset()

training/training/core/dataset.py Outdated Show resolved Hide resolved
training/training/core/dataset.py Outdated Show resolved Hide resolved
shuffle: bool = True,
):
s3 = boto3.client("s3")
obj = s3.get_object(Bucket="dlp-upload-bucket", Key=f"{uid}/tabular/{name}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should allow models other than tabular to access the uploaded datasets but this is okay for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: add error handling if such a directory doesn't exist

Copy link
Member

@karkir0003 karkir0003 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

respond to my questions/comments

y = data[target_name]
X = data.drop(target_name, axis=1)
if y.apply(pd.to_numeric, errors="coerce").isnull().any():
le = LabelEncoder()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@farisdurrani not sure if we need this? If so, should we have a way to track label encoder so that when we build confusion matrix, we have a mapping of number to label?

im having a hard time finding how we solved this problem in the past version of our code?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't recall on top of my head but I believe we did encode the headers in our original code, since the confusion matrix generated only contains numbers. It has been a WIP to map the encodes back to the original labels

image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to store label encoder object for this case in order to recover the original labels? @farisdurrani

If not, any simpler way?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we do. We may be able to store an encoding in the metadata of the uploaded dataset but that's unnecessarily complicated. So just do it manually, passing along the encoder object down the functions

@@ -98,3 +102,64 @@ def getCategoryList(self) -> list[str]:
if self._category_list is None:
raise Exception("Category list not available")
return self._category_list


class CustomDatasetCreator(TrainTestDatasetCreator):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@farisdurrani should we name this class TabularCustomDatasetCreator if the scope is tabular? then we can still preserve extensibility

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NMBridges food for thought

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. Dataset can mean anything, adding Tabular to the name makes it more specific

@karkir0003
Copy link
Member

along with addressing any necessary changes in pr

Copy link
Member

@karkir0003 karkir0003 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added other questions

s3 = boto3.client("s3")
obj = s3.get_object(Bucket="dlp-upload-bucket", Key=f"{uid}/tabular/{name}")
data = pd.read_csv(io.BytesIO(obj["Body"].read()))
y = data[target_name]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@farisdurrani @dwu359 can we guarantee that at the invocation of this function, name of the target col from user uploaded csv dataset to s3 would be available?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should already be a part of the frontend to peek the headers of the csv files from s3 so the users can select the target and feature names. The frontend will send the Trainspace data which includes the target/feature names to training. So, yes. The target col names should be available at this point

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you confirm? @farisdurrani

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need to test this code to make sure it works fine for default and uploaded datasets but I will say yes for now

@NMBridges NMBridges requested a review from a team as a code owner September 27, 2023 02:45
from abc import ABC, abstractmethod
from typing import Callable, Optional, Union, cast

from numpy import ndarray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "numpy" could not be resolved (reportMissingImports)

from typing import Callable, Optional, Union, cast

from numpy import ndarray
from sklearn.model_selection import train_test_split
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "sklearn.model_selection" could not be resolved (reportMissingImports)


from numpy import ndarray
from sklearn.model_selection import train_test_split
from sklearn.utils import Bunch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "sklearn.utils" could not be resolved (reportMissingImports)

from numpy import ndarray
from sklearn.model_selection import train_test_split
from sklearn.utils import Bunch
from sklearn.conftest import fetch_california_housing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "sklearn.conftest" could not be resolved (reportMissingImports)

from sklearn.model_selection import train_test_split
from sklearn.utils import Bunch
from sklearn.conftest import fetch_california_housing
from sklearn.datasets import load_breast_cancer, load_diabetes, load_iris, load_wine
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "sklearn.datasets" could not be resolved (reportMissingImports)

from torch.utils.data import TensorDataset
import numpy as np
import pandas as pd
import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "torch" could not be resolved (reportMissingImports)

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "torch.utils.data" could not be resolved (reportMissingImports)

import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.autograd import Variable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "torch.autograd" could not be resolved (reportMissingImports)

from torch.utils.data import Dataset
from torch.autograd import Variable

from sklearn.preprocessing import LabelEncoder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "sklearn.preprocessing" could not be resolved (reportMissingImports)

from torch.autograd import Variable

from sklearn.preprocessing import LabelEncoder
import boto3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [pyright] reported by reviewdog 🐶
Import "boto3" could not be resolved (reportMissingImports)

@sonarqubecloud
Copy link

SonarCloud Quality Gate failed.    Quality Gate failed

Bug A 0 Bugs
Vulnerability A 0 Vulnerabilities
Security Hotspot A 0 Security Hotspots
Code Smell A 3 Code Smells

No Coverage information No Coverage information
56.9% 56.9% Duplication

idea Catch issues before they fail your Quality Gate with our IDE extension sonarlint SonarLint

Copy link
Member

@farisdurrani farisdurrani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the build errors and resolve all comments, this is good to go for me

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants