WIP - Vector DB investigation #143

Open · wants to merge 5 commits into base: main
293 changes: 287 additions & 6 deletions docs/notebooks/train_model.ipynb
@@ -17,7 +17,7 @@
"source": [
"import fibad\n",
"\n",
"fibad_instance = fibad.Fibad()"
"fibad_instance = fibad.Fibad(config_file=\"/home/drew/code/fibad/drews_config.toml\")"
]
},
{
@@ -33,10 +33,10 @@
"metadata": {},
"outputs": [],
"source": [
"fibad_instance.config[\"model\"][\"name\"] = \"ExampleCNN\"\n",
"fibad_instance.config[\"data_set\"][\"name\"] = \"CifarDataSet\"\n",
"fibad_instance.config[\"model\"][\"name\"] = \"ExampleAutoencoder\"\n",
"fibad_instance.config[\"data_set\"][\"name\"] = \"HSCDataSet\"\n",
"fibad_instance.config[\"data_loader\"][\"batch_size\"] = 64\n",
"fibad_instance.config[\"train\"][\"epochs\"] = 2"
"fibad_instance.config[\"train\"][\"epochs\"] = 20"
]
},
{
@@ -95,12 +95,293 @@
"# and then forward the selected port to your local machine"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Update the config with the trained model that we want to use and set a few other parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fibad_instance.config[\"predict\"][\n",
" \"model_weights_file\"\n",
"] = \"/home/drew/code/fibad/docs/notebooks/results/20241216-203332-train/example_model.pth\"\n",
"fibad_instance.config[\"predict\"][\"split\"] = \"test\"\n",
"fibad_instance.config[\"data_set\"][\"test_size\"] = 1.0\n",
"fibad_instance.config[\"data_set\"][\"train_size\"] = 0.0\n",
"fibad_instance.config[\"data_set\"][\"validate_size\"] = 0.0\n",
"fibad_instance.config[\"data_loader\"][\"batch_size\"] = 128"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run inference on the data set using the specified data and trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fibad_instance.predict()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a copy of the PyTorch data_set object to use as a reference for file names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prepped_output = fibad_instance.prepare()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define a couple of functions to help with plotting and open a connection to our vector database"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import chromadb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from astropy.io import fits\n",
"\n",
"\n",
"# Function to normalize the data to the range [0, 1]\n",
"def normalize(data):\n",
" data_min = np.min(data)\n",
" data_max = np.max(data)\n",
" return (data - data_min) / (data_max - data_min)\n",
"\n",
"\n",
"# Plot our 3 filter images\n",
"def plotter(file_name):\n",
" # Read the FITS files\n",
" base_path = \"/home/drew/code/fibad/docs/notebooks/data/hsc_example/hsc_8asec_1000/\"\n",
" fits_file_r = base_path + file_name + \"_HSC-I.fits\"\n",
" fits_file_g = base_path + file_name + \"_HSC-R.fits\"\n",
" fits_file_b = base_path + file_name + \"_HSC-G.fits\"\n",
"\n",
" data_r = fits.getdata(fits_file_r)\n",
" data_g = fits.getdata(fits_file_g)\n",
" data_b = fits.getdata(fits_file_b)\n",
"\n",
" # Normalize the data\n",
" data_r = normalize(data_r)\n",
" data_g = normalize(data_g)\n",
" data_b = normalize(data_b)\n",
"\n",
" # Combine the data into an RGB image\n",
" rgb_image = np.zeros((data_r.shape[0], data_r.shape[1], 3))\n",
" rgb_image[..., 0] = data_r # Red channel\n",
" rgb_image[..., 1] = data_g # Green channel\n",
" rgb_image[..., 2] = data_b # Blue channel\n",
"\n",
" # Display the image\n",
" plt.imshow(rgb_image, origin=\"lower\")\n",
" plt.axis(\"off\") # Hide the axis\n",
" plt.show()\n",
"\n",
"\n",
"# open a connection to the vector database\n",
"client = chromadb.PersistentClient(path=\"/home/drew/code/fibad/docs/notebooks/results/vdb\")\n",
"collection = client.get_collection(\"fibad_collection\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load one of the .npy files that was saved when we ran the data through the trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"a = np.load(\"/home/drew/code/fibad/docs/notebooks/results/20241216-203830-predict/0.npy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pick a few random embeddings from the file and query the vector database to find the most similar data samples.\n",
"\"Similar\" in this case is the L2 norm metric, $ d = \\sum_{} (A_i - B_i)^2 $.\n",
"\n",
"Cosine similarity and Inner product distance metrics are also supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"# 97 is a cool example\n",
"# 93 is a clean example\n",
"# Pleasantly, 42 is a nice face-on spiral\n",
"indx = 92\n",
"\n",
"query_results = collection.query(\n",
" query_embeddings=[a[indx]],\n",
" n_results=5,\n",
")\n",
"\n",
"print(query_results[\"distances\"])\n",
"\n",
"metadatas = query_results[\"metadatas\"]\n",
"\n",
"files_to_plot = []\n",
"for m in metadatas[0]:\n",
" files = prepped_output.container.files[int(m[\"filename\"])]\n",
" g_file = files[\"HSC-G\"]\n",
" files_to_plot.append(g_file[:-11])\n",
"\n",
"for i, file_name in enumerate(files_to_plot):\n",
" plotter(file_name)"
]
},
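{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, here's a minimal sketch of switching distance metrics. In ChromaDB the metric is fixed per collection at creation time via the `hnsw:space` metadata key (`\"l2\"`, `\"cosine\"`, or `\"ip\"`); the collection name below is hypothetical, and its embeddings would need to be inserted separately."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch, not part of the run above: create a second collection that uses\n",
"# cosine distance instead of the default squared L2. The name\n",
"# \"fibad_collection_cosine\" is made up for illustration.\n",
"cosine_collection = client.get_or_create_collection(\n",
"    name=\"fibad_collection_cosine\",\n",
"    metadata={\"hnsw:space\": \"cosine\"},\n",
")"
]
},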
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's look for outliers. For every entry in the database find a number to represent the distance to it's nearest neighbor.\n",
"\n",
"For instance, it could be the distance to it's closest neighbor, the mean (or in this case median) distance to it's closest N neighbors, etc...\n",
"\n",
"Note that this is an inefficient way to query the database - Chromadb recommends batching the queries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"\n",
"found_files = glob.glob(\"/home/drew/code/fibad/docs/notebooks/results/20241216-203830-predict/*.npy\")\n",
"\n",
"distances = []\n",
"file_names = []\n",
"\n",
"# for each embedding in each output file from inference, calculate a representation of the distance to it's nearest neighbor\n",
"for f in found_files:\n",
" a = np.load(f)\n",
" for i in range(len(a)):\n",
" query_results = collection.query(\n",
" query_embeddings=[a[i]],\n",
" n_results=10,\n",
" )\n",
" distances.append(np.median(query_results[\"distances\"][0][1:]))\n",
" file_names.append(query_results[\"metadatas\"][0][0][\"filename\"])\n",
"\n",
"# print some statistics about the distances\n",
"print(f\"Total values: {len(distances)}\")\n",
"print(f\"Max: {max(distances)} Min: {min(distances)}\")\n",
"print(f\"Mean: {np.mean(distances)} Median: {np.median(distances)} Std: {np.std(distances)}\")\n",
"\n",
"# create a histogram of the distances\n",
"_ = plt.hist(distances, bins=50, range=(0, 3500))"
]
},
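{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's a minimal sketch of the batched alternative that ChromaDB recommends: pass many embeddings to a single `query()` call, which returns one result list per query embedding. This assumes the same `found_files` as above; the batch size is an arbitrary choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of a batched version of the loop above.\n",
"batch_size = 256  # arbitrary; tune for your setup\n",
"batched_distances = []\n",
"batched_file_names = []\n",
"\n",
"all_embeddings = np.concatenate([np.load(f) for f in found_files])\n",
"for start in range(0, len(all_embeddings), batch_size):\n",
"    batch = all_embeddings[start : start + batch_size]\n",
"    results = collection.query(query_embeddings=batch.tolist(), n_results=10)\n",
"    # one distance list and one metadata list come back per query embedding\n",
"    for dists, metas in zip(results[\"distances\"], results[\"metadatas\"]):\n",
"        batched_distances.append(np.median(dists[1:]))  # skip the self-match\n",
"        batched_file_names.append(metas[0][\"filename\"])"
]
},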
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So let's look more closely at the objects in the tail of the histogram. i.e. the ones that are \"far\" from their nearest neighbors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get the indexes of distances where the value is between\n",
"# Near by range: 2750 and 3250\n",
"# Far away range: 5_000 and 30_000\n",
"indexes = [i for i, x in enumerate(distances) if 5_000 < x < 30_000]\n",
"\n",
"# use those indexes to get the file names from the file_names list\n",
"files_to_plot = [file_names[i] for i in indexes]\n",
"files_to_plot\n",
"\n",
"\n",
"# plot the images that are in the range of distances we specified.\n",
"plot_em = []\n",
"names = []\n",
"for m in files_to_plot:\n",
" files = prepped_output.container.files[int(m)]\n",
" g_file = files[\"HSC-G\"]\n",
" plot_em.append(g_file[:-11])\n",
" names.append(g_file[:-11])\n",
"\n",
"for file_name, name in zip(plot_em, names):\n",
" plotter(file_name)\n",
" print(name)"
]
},
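{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, gather all of the latent-space vectors from the inference output files into a single array so we can project them."
]
},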
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"found_files = glob.glob(\"/home/drew/code/fibad/docs/notebooks/results/20241216-203830-predict/*.npy\")\n",
"\n",
"distances = []\n",
"file_names = []\n",
"latent_spaces = []\n",
"\n",
"# for each embedding in each output file from inference, calculate a representation of the distance to it's nearest neighbor\n",
"for f in found_files:\n",
" latent_spaces.append(np.load(f))\n",
"\n",
"latent_spaces = np.asarray(np.concatenate(latent_spaces))\n",
"latent_spaces.shape"
]
},
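{
"cell_type": "markdown",
"metadata": {},
"source": [
"Project the latent space down to two dimensions with UMAP and scatter-plot the result to get a feel for its overall structure."
]
},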
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import umap\n",
"\n",
"reducer = umap.UMAP()\n",
"embedding = reducer.fit_transform(latent_spaces)\n",
"embedding.shape\n",
"plt.scatter(embedding[:, 0], embedding[:, 1], s=1)"
]
}
],
"metadata": {
@@ -119,7 +119,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
"version": "3.12.8"
}
},
"nbformat": 4,