Implemented custom dataset creator class #962
base: nextjs
Conversation
training/training/core/dataset.py
Outdated
    shuffle: bool = True,
):
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket="dlp-upload-bucket", Key=f"{uid}/tabular/{name}")
We should allow models other than tabular to access the uploaded datasets, but this is okay for now.
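A minimal sketch of that generalization, not part of the PR: make the data type a parameter of the S3 key instead of hard-coding "tabular" (the helper and parameter names are illustrative):

def dataset_key(uid: str, name: str, data_type: str = "tabular") -> str:
    # data_type could be "tabular", "image", etc., mirroring the upload layout.
    return f"{uid}/{data_type}/{name}"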
TODO: add error handling if such a directory doesn't exist
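A minimal sketch of that error handling, assuming we want a missing upload to surface as a clear error rather than an unhandled ClientError (the helper name and message are illustrative):

import boto3
from botocore.exceptions import ClientError

def fetch_tabular_object(uid: str, name: str) -> bytes:
    # Bucket and key layout follow the snippet above.
    s3 = boto3.client("s3")
    try:
        obj = s3.get_object(Bucket="dlp-upload-bucket", Key=f"{uid}/tabular/{name}")
    except ClientError as e:
        # S3 reports a missing key as the "NoSuchKey" error code.
        if e.response["Error"]["Code"] == "NoSuchKey":
            raise FileNotFoundError(f"No uploaded dataset at {uid}/tabular/{name}") from e
        raise
    return obj["Body"].read()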
respond to my questions/comments
training/training/core/dataset.py
Outdated
y = data[target_name]
X = data.drop(target_name, axis=1)
if y.apply(pd.to_numeric, errors="coerce").isnull().any():
    le = LabelEncoder()
@farisdurrani not sure if we need this? If so, should we have a way to track label encoder so that when we build confusion matrix, we have a mapping of number to label?
I'm having a hard time finding how we solved this problem in the past version of our code.
Do we need to store the label encoder object for this case in order to recover the original labels? @farisdurrani
If not, is there a simpler way?
Yes, we do. We could store an encoding in the metadata of the uploaded dataset, but that's unnecessarily complicated. So just do it manually, passing the encoder object down through the functions.
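A minimal sketch of that approach, building on the encoding check in the snippet above (the helper name is illustrative); the returned encoder can later recover the original labels via inverse_transform when building the confusion matrix:

from typing import Optional, Tuple

import pandas as pd
from sklearn.preprocessing import LabelEncoder

def encode_target(y: pd.Series) -> Tuple[pd.Series, Optional[LabelEncoder]]:
    # Only encode when the target column is not already numeric.
    if y.apply(pd.to_numeric, errors="coerce").isnull().any():
        le = LabelEncoder()
        return pd.Series(le.fit_transform(y), index=y.index), le
    return y, None

# Later, e.g. when labeling a confusion matrix:
# class_names = encoder.inverse_transform(range(len(encoder.classes_))) if encoder else None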
training/training/core/dataset.py
Outdated
@@ -98,3 +102,64 @@ def getCategoryList(self) -> list[str]:
        if self._category_list is None:
            raise Exception("Category list not available")
        return self._category_list


class CustomDatasetCreator(TrainTestDatasetCreator):
@farisdurrani should we name this class TabularCustomDatasetCreator if the scope is tabular? Then we can still preserve extensibility.
@NMBridges food for thought
That's a good idea. Dataset can mean anything; adding Tabular to the name makes it more specific.
Along with addressing any necessary changes in the PR.
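A hedged sketch of the rename, assuming the class otherwise keeps the implementation from this PR (the docstring is illustrative):

class TabularCustomDatasetCreator(TrainTestDatasetCreator):
    """Creates train/test datasets from a user-uploaded tabular (CSV) file.

    Scoping the name to tabular leaves room for sibling creators
    (e.g. image or audio) under the same TrainTestDatasetCreator interface.
    """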
added other questions
training/training/core/dataset.py
Outdated
s3 = boto3.client("s3")
obj = s3.get_object(Bucket="dlp-upload-bucket", Key=f"{uid}/tabular/{name}")
data = pd.read_csv(io.BytesIO(obj["Body"].read()))
y = data[target_name]
@farisdurrani @dwu359 can we guarantee that, at the invocation of this function, the name of the target column from the user-uploaded CSV dataset in S3 will be available?
The frontend already peeks at the headers of the CSV files from S3 so users can select the target and feature names, and it sends the Trainspace data, which includes those target/feature names, to training. So yes, the target column names should be available at this point.
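A minimal sketch of that header peek, assuming the bucket/key layout from the snippets in this PR (fetching only the first bytes via a Range request is an illustrative optimization, not necessarily what the frontend does):

import io

import boto3
import pandas as pd

def peek_csv_columns(uid: str, name: str) -> list[str]:
    s3 = boto3.client("s3")
    # A Range request avoids downloading the whole file just for the header row.
    obj = s3.get_object(
        Bucket="dlp-upload-bucket",
        Key=f"{uid}/tabular/{name}",
        Range="bytes=0-65535",
    )
    first_line = obj["Body"].read().split(b"\n", 1)[0]
    return list(pd.read_csv(io.BytesIO(first_line)).columns)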
Can you confirm? @farisdurrani
We'll need to test this code to make sure it works fine for default and uploaded datasets, but I will say yes for now.
from abc import ABC, abstractmethod
from typing import Callable, Optional, Union, cast

from numpy import ndarray
from sklearn.model_selection import train_test_split
from sklearn.utils import Bunch
from sklearn.conftest import fetch_california_housing
from sklearn.datasets import load_breast_cancer, load_diabetes, load_iris, load_wine
from torch.utils.data import TensorDataset
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.autograd import Variable

from sklearn.preprocessing import LabelEncoder
import boto3

🚫 [pyright] reported by reviewdog 🐶
Import could not be resolved (reportMissingImports) for each of: "numpy", "sklearn.model_selection", "sklearn.utils", "sklearn.conftest", "sklearn.datasets", "torch", "torch.utils.data", "torch.autograd", "sklearn.preprocessing", and "boto3".
SonarCloud Quality Gate failed. 0 Bugs. No Coverage information.
Fix the build errors and resolve all comments; this is good to go for me.
Added CustomDatasetCreator class
What user problem are we solving?
#913
What solution does this PR provide?
Allows the backend to use user-uploaded tabular datasets for training.
Testing Methodology
Uploaded a testing dataset to dlp-upload-bucket/nolan/tabular/antennae-lengths.csv, performed a function call on the endpoint, and verified that the data was preserved during download. Also implemented automatic string encoding for non-numeric labels.
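A minimal sketch of that round-trip check, assuming a local copy of the uploaded file is available (the test body is illustrative, not the PR's actual test):

import io

import boto3
import pandas as pd

def test_upload_roundtrip():
    local = pd.read_csv("antennae-lengths.csv")
    s3 = boto3.client("s3")
    obj = s3.get_object(
        Bucket="dlp-upload-bucket", Key="nolan/tabular/antennae-lengths.csv"
    )
    remote = pd.read_csv(io.BytesIO(obj["Body"].read()))
    # The downloaded frame should match the local file exactly.
    pd.testing.assert_frame_equal(local, remote)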
Any other considerations