diff --git a/docs/notebooks/run_full_sim.ipynb b/docs/notebooks/run_full_sim.ipynb
new file mode 100644
index 0000000..695f73b
--- /dev/null
+++ b/docs/notebooks/run_full_sim.ipynb
@@ -0,0 +1,877 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "247291b3-da31-4eef-a89e-0360ec300bdb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import resspect\n",
+    "import pandas as pd\n",
+    "from resspect import request_TOM_data\n",
+    "from resspect import fit_TOM, fit\n",
+    "from resspect import submit_queries_to_TOM\n",
+    "from resspect import time_domain_loop\n",
+    "from resspect.tom_client import TomClient\n",
+    "from resspect import TimeDomainConfiguration\n",
+    "import os\n",
+    "import re\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "30f96d54-c60e-4ce5-b150-eb6b57c50e0b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "###🔲 Need to import this from RESSPECT\n",
+    "###🔲 Need to put this updated version in RESSPECT \n",
+    "def update_pool_stash(day: int):\n",
+    "    outdir = 'TOM_days_storage'\n",
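+    "    # minimal guard completing the original 'check if a directory exists' TODO\n",
+    "    os.makedirs(outdir, exist_ok=True)\n",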
+    "\n",
+    "    #should we store old features somewhere? makes it easier to add training objs\n",
+    "    #would want to add current MJD, maybe first MJD, and peak MJD\n",
+    "    if day != 0:\n",
+    "        current_stash_path = outdir+'/TOM_compiled_features_day'+str(day-1)+'.csv'\n",
+    "    else:\n",
+    "        current_stash_path = outdir+'/TOM_hot_features_day_'+str(day)+'.csv'\n",
+    "\n",
+    "    new_night_path = outdir+'/TOM_hot_features_day_'+str(day)+'.csv'\n",
+    "\n",
+    "    current_stash_df = pd.read_csv(current_stash_path)\n",
+    "    new_night_df = pd.read_csv(new_night_path)\n",
+    "\n",
+    "    # keep only the newest feature row for each object id\n",
+    "    compiled_df = pd.concat([current_stash_df, new_night_df]).drop_duplicates('id', keep='last')\n",
+    "\n",
+    "    output_path = outdir+'/TOM_compiled_features_day'+str(day)+'.csv'\n",
+    "    # rewrite the stash file for this day\n",
+    "    compiled_df.to_csv(output_path, index=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "01223312-9b79-4355-817b-fe98497c23de",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#🔲 Need to remove SN that are no longer hot, which will help speed things up\n",
+    "#🔲 in the meantime, remove objects after 15 days\n",
+    "def remove_from_pool_stash(day):\n",
+    "    outdir = 'TOM_days_storage'\n",
+    "    os.makedirs(outdir, exist_ok=True)\n",
+    "    current_stash_path = outdir+'/TOM_hot_features_day_'+str(day)+'.csv'\n",
+    "\n",
+    "    current_stash_df = pd.read_csv(current_stash_path)\n",
+    "\n",
+    "    # drop objects that entered the pool more than 15 days ago\n",
+    "    # (uses the 'date_added' column stamped in run_one_night below)\n",
+    "    remove_old_obj_df = current_stash_df[current_stash_df[\"date_added\"] > day-15]\n",
+    "\n",
+    "    output_path = outdir+'/TOM_compiled_features_day'+str(day)+'.csv'\n",
+    "    # rewrite the stash file without the stale objects\n",
+    "    remove_old_obj_df.to_csv(output_path, index=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c24acd97",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "additional_info = [\n",
+    "    'hostgal_snsep',\n",
+    "    'hostgal_ellipticity',\n",
+    "    'hostgal_sqradius',\n",
+    "    'hostgal_mag_u',\n",
+    "    'hostgal_mag_g',\n",
+    "    'hostgal_mag_r',\n",
+    "    'hostgal_mag_i',\n",
+    "    'hostgal_mag_z',\n",
+    "    'hostgal_mag_y',\n",
+    "    'hostgal_magerr_u',\n",
+    "    'hostgal_magerr_g',\n",
+    "    'hostgal_magerr_r',\n",
+    "    'hostgal_magerr_i',\n",
+    "    'hostgal_magerr_z',\n",
+    "    'hostgal_magerr_y',\n",
+    "]\n",
+    "\n",
+    "from laiss_resspect_classifier.elasticc2_laiss_feature_extractor import Elasticc2LaissFeatureExtractor\n",
+    "\n",
+    "def validate_objects(objects_to_test):\n",
+    "    fe = Elasticc2LaissFeatureExtractor()\n",
+    "    good_objs = []\n",
+    "\n",
+    "    for t_obj in objects_to_test:\n",
+    "        fe.photometry = pd.DataFrame(t_obj['photometry'])\n",
+    "        fe.id = t_obj['objectid']\n",
+    "\n",
+    "        fe.additional_info = {}\n",
+    "        for info in additional_info:\n",
+    "            fe.additional_info[info] = t_obj[info]\n",
+    "\n",
+    "        res = fe.fit_all()\n",
+    "        if res:\n",
+    "            good_objs.append(t_obj)\n",
+    "\n",
+    "    return good_objs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cd30be07-410d-44f3-a878-dfed0be1fafb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_phot(obj_df):\n",
+    "\n",
+    "    tom = TomClient(url = \"https://desc-tom-2.lbl.gov\", username = 'awoldag', passwordfile = '../../../password.txt')\n",
+    "\n",
+    "    # get all of the photometry at once\n",
+    "    ids = obj_df['diaobject_id'].tolist()\n",
+    "    res = tom.post('db/runsqlquery/',\n",
+    "                   json={ 'query': 'SELECT diaobject_id, filtername, midpointtai, psflux, psfluxerr' \n",
+    "                          ' FROM elasticc2_ppdbdiaforcedsource' \n",
+    "                          ' WHERE diaobject_id IN (%s) ORDER BY diaobject_id, filtername, midpointtai;' % (', '.join(str(id) for id in ids)),\n",
+    "                          'subdict': {} } )\n",
+    "    all_phot = res.json()['rows']\n",
+    "    all_phot_df = pd.DataFrame(all_phot)\n",
+    "    # convert the arbitrary-zeropoint flux to magnitudes\n",
+    "    all_phot_df['mag'] = -2.5*np.log10(all_phot_df['psflux']) + 27.5\n",
+    "    all_phot_df['magerr'] = 2.5/np.log(10) * all_phot_df['psfluxerr']/all_phot_df['psflux']\n",
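+    "    # note: forced photometry can report psflux <= 0, which makes the log10\n",
+    "    # above produce NaN/inf magnitudes; if that trips up downstream code, an\n",
+    "    # optional guard (an untested suggestion, not part of the pipeline) is:\n",
+    "    # all_phot_df = all_phot_df[all_phot_df['psflux'] > 0]\n",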
+    "\n",
+    "    #! Need to send Rob a message to ask that these features be included when querying for hot supernovae\n",
+    "    host_res = tom.post('db/runsqlquery/',\n",
+    "                   json={ 'query': 'SELECT diaobject_id, hostgal_mag_u, hostgal_mag_g, hostgal_mag_r, hostgal_mag_i, hostgal_mag_z, hostgal_mag_Y, hostgal_magerr_u, hostgal_magerr_g, hostgal_magerr_r, hostgal_magerr_i, hostgal_magerr_z, hostgal_magerr_Y, hostgal_snsep, hostgal_ellipticity, hostgal_sqradius'\n",
+    "                          ' FROM elasticc2_ppdbdiaobject'\n",
+    "                          ' WHERE diaobject_id IN (%s) ORDER BY diaobject_id;' % (', '.join(str(id) for id in ids)),\n",
+    "                          'subdict': {} } )\n",
+    "    all_host = host_res.json()['rows']\n",
+    "    # index the host rows by object id so the merge below does not rely on\n",
+    "    # obj_df and the SQL result sharing a row order\n",
+    "    host_by_id = {row['diaobject_id']: row for row in all_host}\n",
+    "\n",
+    "    # format into a list of dicts\n",
+    "    data = []\n",
+    "    for idx, obj in obj_df.iterrows():\n",
+    "        phot = all_phot_df[all_phot_df['diaobject_id'] == obj['diaobject_id']]\n",
+    "\n",
+    "        phot_d = {}\n",
+    "        phot_d['objectid'] = int(obj['diaobject_id'])\n",
+    "        phot_d['sncode'] = int(obj['gentype'])\n",
+    "        phot_d['redshift'] = obj['zcmb']\n",
+    "        phot_d['ra'] = obj['ra']\n",
+    "        phot_d['dec'] = obj['dec']\n",
+    "        phot_d['photometry'] = phot[['filtername', 'midpointtai', 'psflux', 'psfluxerr', 'mag', 'magerr']].to_dict(orient='list')\n",
+    "\n",
+    "        # rename the TOM columns to the names the feature extractor expects\n",
+    "        phot_d['photometry']['band'] = phot_d['photometry'].pop('filtername')\n",
+    "        phot_d['photometry']['mjd'] = phot_d['photometry'].pop('midpointtai')\n",
+    "        phot_d['photometry']['flux'] = phot_d['photometry'].pop('psflux')\n",
+    "        phot_d['photometry']['fluxerr'] = phot_d['photometry'].pop('psfluxerr')\n",
+    "        phot_d = {**phot_d, **host_by_id[int(obj['diaobject_id'])]}\n",
+    "        del phot_d['diaobject_id']\n",
+    "        data.append(phot_d)\n",
+    "\n",
+    "    return data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7da5d212-81e7-43d8-8592-f9e8ccae00fd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_phot_orig(obj_df):\n",
+    "    # get all of the photometry at once\n",
+    "    ids = obj_df['diaobject_id'].tolist()\n",
+    "    res = tom.post('db/runsqlquery/',\n",
+    "                   json={ 'query': 'SELECT diaobject_id, filtername, midpointtai, psflux, psfluxerr' \n",
+    "                          ' FROM elasticc2_ppdbdiaforcedsource' \n",
+    "                          ' WHERE diaobject_id IN (%s) ORDER BY diaobject_id, filtername, midpointtai;' % (', '.join(str(id) for id in ids)),\n",
+    "                          'subdict': {} } )\n",
+    "    all_phot = res.json()['rows']\n",
+    "    all_phot_df = pd.DataFrame(all_phot)\n",
+    "    # if you need mag from the arbitrary flux\n",
+    "    # all_phot_df['mag'] = -2.5*np.log10(all_phot_df['psflux']) + 27.5\n",
+    "    # all_phot_df['magerr'] = 2.5/np.log(10) * all_phot_df['psfluxerr']/all_phot_df['psflux']\n",
+    "\n",
+    "    # format into a list of dicts\n",
+    "    data = []\n",
+    "    for idx, obj in obj_df.iterrows():\n",
+    "        phot = all_phot_df[all_phot_df['diaobject_id'] == obj['diaobject_id']]\n",
+    "        \n",
+    "        phot_d = {}\n",
+    "        phot_d['objectid'] = int(obj['diaobject_id'])\n",
+    "        phot_d['sncode'] = int(obj['gentype'])\n",
+    "        phot_d['redshift'] = obj['zcmb']\n",
+    "        phot_d['ra'] = obj['ra']\n",
+    "        phot_d['dec'] = obj['dec']\n",
+    "        phot_d['photometry'] = phot[['filtername', 'midpointtai', 'psflux', 'psfluxerr']].to_dict(orient='list')\n",
+    "\n",
+    "        phot_d['photometry']['band'] = phot_d['photometry']['filtername']\n",
+    "        phot_d['photometry']['mjd'] = 
phot_d['photometry']['midpointtai']\n", + " phot_d['photometry']['flux'] = phot_d['photometry']['psflux']\n", + " phot_d['photometry']['fluxerr'] = phot_d['photometry']['psfluxerr']\n", + " del phot_d['photometry']['filtername']\n", + " del phot_d['photometry']['midpointtai']\n", + " del phot_d['photometry']['psflux']\n", + " del phot_d['photometry']['psfluxerr']\n", + " \n", + " data.append(phot_d)\n", + "\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cae4fd9b-6866-40ce-a536-8c50c2337c6e", + "metadata": {}, + "outputs": [], + "source": [ + "#MAKE INITIAL TRAINING SET \n", + "objs = []\n", + "\n", + "tom = TomClient(url = \"https://desc-tom-2.lbl.gov\", username = 'awoldag', passwordfile = '../../../password.txt')\n", + "\n", + "res = tom.post('db/runsqlquery/',\n", + " json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n", + " ' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61300 and peakmjd<61309 and gentype=10 limit 10;', \n", + " 'subdict': {}} )\n", + "objs.extend(res.json()['rows'])\n", + "gentypes = [20,21,25,26,27,12,40,42,59]\n", + "for gentype in gentypes:\n", + " res = tom.post('db/runsqlquery/',\n", + " json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n", + " f' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61300 and peakmjd<61309 and gentype={gentype} limit 5;', \n", + " 'subdict': {}} )\n", + " objs.extend(res.json()['rows'])\n", + "\n", + "# res = tom.post('db/runsqlquery/',\n", + "# json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n", + "# ' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61300 and peakmjd<61309 and gentype=31 limit 5;', \n", + "# 'subdict': {}} )\n", + "# objs.extend(res.json()['rows'])\n", + "\n", + "training_objs = get_phot(pd.DataFrame(objs))\n", + "good_objs = validate_objects(training_objs)\n", + "\n", + "outdir = 'TOM_days_storage'\n", + "\n", + "if not os.path.exists(outdir):\n", + " os.makedirs(outdir)\n", + "\n", + "# 🔲 change this to fit()\n", + "feature_extraction_method = 'laiss_resspect_classifier.elasticc2_laiss_feature_extractor.Elasticc2LaissFeatureExtractor'\n", + "fit(\n", + " good_objs,\n", + " output_features_file = outdir+'/TOM_training_features',\n", + " feature_extractor = feature_extraction_method,\n", + " filters='ZTF',\n", + " additional_info=additional_info,\n", + " one_code=gentypes,\n", + ")\n", + "data = pd.read_csv('TOM_days_storage/TOM_training_features',index_col=False)\n", + "data['orig_sample'] = 'train'\n", + "data[\"type\"] = np.where(data[\"sncode\"] == 10, 'Ia', 'other')\n", + "data.to_csv('TOM_days_storage/TOM_training_features',index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52da9afa-5b1f-4c66-80b6-77608dcd30ac", + "metadata": {}, + "outputs": [], + "source": [ + "#MAKE TEST SET \n", + "objs = []\n", + "\n", + "tom = TomClient(url = \"https://desc-tom-2.lbl.gov\", username = 'awoldag', passwordfile = '../../../password.txt')\n", + "\n", + "res = tom.post('db/runsqlquery/',\n", + " json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n", + " ' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61310 and peakmjd<61339 and gentype=10 limit 1000;', \n", + " 'subdict': {}} )\n", + "objs.extend(res.json()['rows'])\n", + "\n", + "gentypes = [20,21,25,26,27,12,40,42,59]\n", + "for gentype in gentypes:\n", + " res = tom.post('db/runsqlquery/',\n", + " json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n", + " f' 
peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61310 and peakmjd<61339 and gentype = {gentype} limit 100;', \n",
+    "                          'subdict': {}} )\n",
+    "    objs.extend(res.json()['rows'])\n",
+    "# res = tom.post('db/runsqlquery/',\n",
+    "#                json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n",
+    "#                       ' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61310 and peakmjd<61339 and gentype=31 limit 500;', \n",
+    "#                       'subdict': {}} )\n",
+    "# objs.extend(res.json()['rows'])\n",
+    "\n",
+    "test_objs = get_phot(pd.DataFrame(objs))\n",
+    "good_objs = validate_objects(test_objs)\n",
+    "\n",
+    "outdir = 'TOM_days_storage'\n",
+    "\n",
+    "if not os.path.exists(outdir):\n",
+    "    os.makedirs(outdir)\n",
+    "\n",
+    "# 🔲 change this to fit()\n",
+    "feature_extraction_method = 'laiss_resspect_classifier.elasticc2_laiss_feature_extractor.Elasticc2LaissFeatureExtractor'\n",
+    "fit(\n",
+    "    good_objs,\n",
+    "    output_features_file = outdir+'/TOM_testing_features',\n",
+    "    feature_extractor = feature_extraction_method,\n",
+    "    filters='ZTF',\n",
+    "    additional_info=additional_info,\n",
+    "    one_code=gentypes,\n",
+    ")\n",
+    "data = pd.read_csv('TOM_days_storage/TOM_testing_features',index_col=False)\n",
+    "# mark these rows as the test sample\n",
+    "data['orig_sample'] = 'test'\n",
+    "data[\"type\"] = np.where(data[\"sncode\"] == 10, 'Ia', 'other')\n",
+    "data.to_csv('TOM_days_storage/TOM_testing_features',index=False)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9f4ff59b-e018-4777-a709-736b98fb8e43",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#✅ MAKE VALIDATION SET \n",
+    "objs = []\n",
+    "\n",
+    "tom = TomClient(url = \"https://desc-tom-2.lbl.gov\", username = 'awoldag', passwordfile = '../../../password.txt')\n",
+    "\n",
+    "res = tom.post('db/runsqlquery/',\n",
+    "               json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n",
+    "                      ' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61340 and gentype=10 limit 1000;', \n",
+    "                      'subdict': {}} )\n",
+    "objs.extend(res.json()['rows'])\n",
+    "\n",
+    "gentypes = [20,21,25,26,27,12,40,42,59]\n",
+    "for gentype in gentypes:\n",
+    "    res = tom.post('db/runsqlquery/',\n",
+    "                   json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n",
+    "                          f' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61340 and gentype={gentype} limit 100;', \n",
+    "                          'subdict': {}} )\n",
+    "    objs.extend(res.json()['rows'])\n",
+    "\n",
+    "# res = tom.post('db/runsqlquery/',\n",
+    "#                json={ 'query': 'SELECT diaobject_id, gentype, zcmb, peakmjd,' \n",
+    "#                       ' peakmag_g, ra, dec FROM elasticc2_diaobjecttruth WHERE peakmjd>61340 and gentype=31 limit 500;', \n",
+    "#                       'subdict': {}} )\n",
+    "# objs.extend(res.json()['rows'])\n",
+    "\n",
+    "val_objs = get_phot(pd.DataFrame(objs))\n",
+    "good_objs = validate_objects(val_objs)\n",
+    "\n",
+    "outdir = 'TOM_days_storage'\n",
+    "\n",
+    "if not os.path.exists(outdir):\n",
+    "    os.makedirs(outdir)\n",
+    "\n",
+    "# 🔲 change this to fit()\n",
+    "feature_extraction_method = 'laiss_resspect_classifier.elasticc2_laiss_feature_extractor.Elasticc2LaissFeatureExtractor'\n",
+    "fit(\n",
+    "    good_objs,\n",
+    "    output_features_file = outdir+'/TOM_validation_features',\n",
+    "    feature_extractor = feature_extraction_method,\n",
+    "    filters='ZTF',\n",
+    "    additional_info=additional_info,\n",
+    "    one_code=gentypes,\n",
+    ")\n",
+    "data = pd.read_csv('TOM_days_storage/TOM_validation_features',index_col=False)\n",
+    "# mark these rows as the validation sample\n",
+    "data['orig_sample'] = 'validation'\n",
+    "data[\"type\"] = np.where(data[\"sncode\"] == 10, 'Ia', 'other')\n",
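+    "# quick sanity check, a hedged addition: confirm both classes are present\n",
+    "# in the file the learning loop will read\n",
+    "print(data['type'].value_counts())\n",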
"data.to_csv('TOM_days_storage/TOM_validation_features',index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcadd173-4778-406c-80e4-4dccdddc8365", + "metadata": {}, + "outputs": [], + "source": [ + "def run_one_night(day): \n", + " #🔲 check for new spec+classification for training set before running the loop\n", + " #✅flag for simulated vs real data\n", + " #🔲implement an auto-check\n", + " #✅first thing, check if it's real or sim data\n", + " #🔲MAKE CLEAR WHAT IS FOR REAL DATA VS SIM (eg. real data will want date)\n", + " #✅request light curve data from the TOM - for real and simulated\n", + " \n", + " #get new lc info from TOM (from yesterday (for now))\n", + " # data_dic = request_TOM_data(url = \"https://desc-tom-2.lbl.gov\", username='awoldag',\n", + " # passwordfile='../../../password.txt', detected_in_last_days = 1,\n", + " # mjdnow = 60796+day-1, cheat_gentypes = [82, 10, 21, 27, 26, 37, 32, 36, 31, 89])\n", + " data_dic = request_TOM_data(url = \"https://desc-tom-2.lbl.gov\", username='awoldag',\n", + " passwordfile='../../../password.txt', detected_in_last_days = 1,\n", + " mjdnow = 60796+day-1, cheat_gentypes = [20,21,25,26,27,12,40,42,59,10,30,31,32,35,36,37,11])\n", + " \n", + " data_dic=data_dic['diaobject']\n", + "\n", + " good_dic = validate_objects(data_dic)\n", + "\n", + " # feature_extraction_method = 'Malanchev'\n", + " feature_extraction_method = 'laiss_resspect_classifier.elasticc2_laiss_feature_extractor.Elasticc2LaissFeatureExtractor'\n", + " # classifier = 'RandomForest'\n", + " classifier = 'laiss_resspect_classifier.laiss_classifier.LaissRandomForest'\n", + "\n", + " #✅run that data through RESSPECT to get features\n", + " #🔲at some point, cut out objects that are not likely SN - do this before it gets to RESSPECT probably\n", + " #✅clarify out file argument/data base\n", + "\n", + "\n", + " # if day >= 0:\n", + "# file_name = outdir+'/TOM_hot_features_day_'+str(day)+'.csv'\n", + "# else:\n", + "# file_name = outdir+'/TOM_hot_features.csv'\n", + " \n", + " #get features from that data\n", + " outdir = 'TOM_days_storage'\n", + " file_name = outdir+'/TOM_hot_features_day_'+str(day-1)+'.csv'\n", + " \n", + " fit(\n", + " good_dic,\n", + " output_features_file = file_name,\n", + " feature_extractor = feature_extraction_method,\n", + " filters='ZTF',\n", + " additional_info=additional_info\n", + " )\n", + " data = pd.read_csv(file_name, index_col=False)\n", + " data['orig_sample'] = 'pool'\n", + " #add date added so that we can remove when they are too old\n", + "\n", + " \n", + " data.to_csv(file_name,index=False)\n", + "\n", + " # -------------------------\n", + "\n", + " #get new lc info from TOM (for today)\n", + " # data_dic = request_TOM_data(url = \"https://desc-tom-2.lbl.gov\",username='awoldag',\n", + " # passwordfile='../../../password.txt',detected_in_last_days = 1, mjdnow = 60796+day, \n", + " # cheat_gentypes = [82, 10, 21, 27, 26, 37, 32, 36, 31, 89])\n", + " data_dic = request_TOM_data(url = \"https://desc-tom-2.lbl.gov\",username='awoldag',\n", + " passwordfile='../../../password.txt',detected_in_last_days = 1, mjdnow = 60796+day, \n", + " cheat_gentypes = [20,21,25,26,27,12,40,42,59,10,30,31,32,35,36,37,11])\n", + " data_dic=data_dic['diaobject']\n", + " good_dic = validate_objects(data_dic)\n", + " #get features from that data\n", + " outdir = 'TOM_days_storage'\n", + " file_name = outdir+'/TOM_hot_features_day_'+str(day)+'.csv'\n", + " \n", + " fit(\n", + " good_dic,\n", + " output_features_file = file_name,\n", + " 
+    "    data.to_csv(file_name,index=False)\n",
+    "\n",
+    "    # -------------------------\n",
+    "\n",
+    "    #get new lc info from the TOM (for today)\n",
+    "    # data_dic = request_TOM_data(url = \"https://desc-tom-2.lbl.gov\",username='awoldag',\n",
+    "    #                             passwordfile='../../../password.txt',detected_in_last_days = 1, mjdnow = 60796+day, \n",
+    "    #                             cheat_gentypes = [82, 10, 21, 27, 26, 37, 32, 36, 31, 89])\n",
+    "    data_dic = request_TOM_data(url = \"https://desc-tom-2.lbl.gov\",username='awoldag',\n",
+    "                                passwordfile='../../../password.txt',detected_in_last_days = 1, mjdnow = 60796+day, \n",
+    "                                cheat_gentypes = [20,21,25,26,27,12,40,42,59,10,30,31,32,35,36,37,11])\n",
+    "    data_dic=data_dic['diaobject']\n",
+    "    good_dic = validate_objects(data_dic)\n",
+    "    #get features from that data\n",
+    "    outdir = 'TOM_days_storage'\n",
+    "    file_name = outdir+'/TOM_hot_features_day_'+str(day)+'.csv'\n",
+    "    \n",
+    "    fit(\n",
+    "        good_dic,\n",
+    "        output_features_file = file_name,\n",
+    "        feature_extractor = feature_extraction_method,\n",
+    "        filters='ZTF',\n",
+    "        additional_info=additional_info,\n",
+    "    )\n",
+    "    data = pd.read_csv(file_name, index_col=False)\n",
+    "    data['orig_sample'] = 'pool'\n",
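+    "    # same hedged date stamp for tonight's objects\n",
+    "    data['date_added'] = day\n",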
+    "    \n",
+    "    data.to_csv(file_name,index=False)\n",
+    "\n",
+    "    #✅update feature lists\n",
+    "    #✅ change this so that we have a new file for each day - just make it so that update_pool_stash writes to a new file \n",
+    "    #✅(and puts this file in a directory)\n",
+    "    #🔲 Probably want to make the file names more general down the line\n",
+    "#     update_pool_stash(day)\n",
+    "\n",
+    "#! #####################################################\n",
+    "#! Everything below here was moved to run_td_loop below\n",
+    "#! #####################################################\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9aabfe56-4141-4460-9fcb-ccb1bdec13b0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_one_night(50)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ea30c097",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def run_td_loop(day):\n",
+    "    # feature_extraction_method = 'Malanchev'\n",
+    "    feature_extraction_method = 'laiss_resspect_classifier.elasticc2_laiss_feature_extractor.Elasticc2LaissFeatureExtractor'\n",
+    "    # classifier = 'RandomForest'\n",
+    "    classifier = 'laiss_resspect_classifier.laiss_classifier.LaissRandomForest'\n",
+    "    \n",
+    "    # run the loop to get queried objects and updated metrics\n",
+    "    days = [day-1, day+1] # first and last day of the survey\n",
+    "    training = None # if int, take that number of objs for initial training, 50% being Ia\n",
+    "    \n",
+    "    strategy = 'UncSampling' # learning strategy\n",
+    "    batch = 5 # if int, ignore cost per observation; if None, find optimal batch size\n",
+    "    \n",
+    "    sep_files = True # if True, expects train, test and validation samples in separate files\n",
+    "    \n",
+    "    path_to_features_dir = 'TOM_days_storage/' # folder where the files for each day are stored\n",
+    "    \n",
+    "    # output results for metrics\n",
+    "    output_metrics_file = 'results/metrics_' + strategy + '_' + str('ini_train_set') + \\\n",
+    "                          '_batch' + str(batch) + '.csv'\n",
+    "    \n",
+    "    # output query sample\n",
+    "    output_query_file = 'results/queried_' + strategy + '_' + str('ini_train_set') + \\\n",
+    "                        '_batch' + str(batch) + '_day_'+ str(day) + '.csv'\n",
+    "    \n",
+    "    path_to_ini_files = {}\n",
+    "    \n",
+    "    # features from full light curves for initial training sample\n",
+    "    path_to_ini_files['train'] = 'TOM_days_storage/TOM_training_features'\n",
+    "    path_to_ini_files['test'] = 'TOM_days_storage/TOM_testing_features'\n",
+    "    path_to_ini_files['validation'] = 'TOM_days_storage/TOM_validation_features'\n",
+    "    \n",
+    "    survey='ZTF'\n",
+    "    \n",
+    "    n_estimators = 1000 # number of trees in the forest\n",
+    "    \n",
+    "    screen = False # if True, will print many things for debugging\n",
+    "    fname_pattern = ['TOM_hot_features_day_', '.csv'] # pattern for the filenames where different days are stored\n",
+    "    \n",
+    "    queryable = False # if True, check brightness before considering an object queryable\n",
+    "    \n",
+    "    # run time domain loop\n",
+    "    time_domain_loop(TimeDomainConfiguration(days=days, output_metrics_file=output_metrics_file,\n",
+    "                                             output_queried_file=output_query_file,\n",
+    "                                             path_to_ini_files=path_to_ini_files,\n",
+    "                                             path_to_features_dir=path_to_features_dir,\n",
+    "                                             strategy=strategy, fname_pattern=fname_pattern, batch=batch,\n",
+    "                                             classifier=classifier,\n",
+    "                                             sep_files=sep_files,\n",
+    "                                             survey=survey, queryable=queryable,\n",
+    "                                             feature_extraction_method=feature_extraction_method),\n",
+    "                     screen=screen, n_estimators=n_estimators,\n",
+    "                     )\n",
+    "\n",
+    "    #🔲 do we want higher entropy in our returned objects?\n",
+    "    # Read in RESSPECT requests to input to TOM format\n",
+    "    ids = list(pd.read_csv(output_query_file)['objectid'])\n",
+    "    ids = [int(id) for id in ids]\n",
+    "    # split the queried ids into 5 equal-as-possible priority tiers\n",
+    "    num = int(len(ids)/5)\n",
+    "    mod = len(ids)%5\n",
+    "    num_list = [num]*5\n",
+    "    mod_list = []\n",
+    "    for i in range(mod):\n",
+    "        mod_list.append(1)\n",
+    "    rem = 5-len(mod_list)\n",
+    "    mod_list = mod_list+[0]*rem\n",
+    "    num_list=list(np.asarray(num_list)+mod_list)\n",
+    "    priorities = [1]*num_list[0]+[2]*num_list[1]+[3]*num_list[2]+[4]*num_list[3]+[5]*num_list[4]\n",
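+    "    # equivalent, more direct tiering (a hedged alternative, left as a note):\n",
+    "    # priorities = [tier+1 for tier, chunk in enumerate(np.array_split(ids, 5)) for _ in chunk]\n",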
+    "    \n",
+    "    # send these queried objects to the TOM\n",
+    "    # submit_queries_to_TOM('awoldag', '../../../password.txt', objectids = ids, priorities = priorities, requester = 'resspect')\n",
+    "    print(ids, priorities)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "012b8082",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_td_loop(50)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "00d8cba2-e04d-44e8-84f1-0905365fae94",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Pull classified obj and add them to the training set\n",
+    "def get_classified(username, passwordfile=None, password=None, since = None):\n",
+    "    tom = TomClient(url = \"https://desc-tom-2.lbl.gov\", username = username, password = password, \n",
+    "                    passwordfile = passwordfile)\n",
+    "    dic = {}\n",
+    "    if since is not None:\n",
+    "        dic['since'] = since\n",
+    "\n",
+    "    res = tom.post( 'elasticc2/getknownspectruminfo', json=dic )\n",
+    "\n",
+    "    assert res.status_code == 200\n",
+    "    assert res.json()['status'] == \"ok\"\n",
+    "    reqs = res.json()\n",
+    "    return reqs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f88f51ec-a375-4563-aaad-77582ea984ab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "classed_obj = get_classified('amandaw8', passwordfile='/Users/arw/secrets/TOM2', since = '11/22/2024 19:20:00')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "04029ea7-337b-4efe-9f98-286487b35989",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "objectids = []\n",
+    "classes = []\n",
+    "for obj in classed_obj['spectra']:\n",
+    "    objectids.append(obj['objectid'])\n",
+    "    if obj['classid'] == 2222:\n",
+    "        classes.append('Ia')\n",
+    "    else:\n",
+    "        classes.append('other')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f8af7dc4-d30b-4c5d-b2ec-4770bb2f4581",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_object_phot(username, passwordfile=None, password=None, obj_ids=[]):\n",
+    "    tom = TomClient(url = \"https://desc-tom-2.lbl.gov\", username = username, password = password, \n",
+    "                    passwordfile = passwordfile)\n",
+    "    dic = {'obj_ids': obj_ids}\n",
+    "\n",
+    "    res = tom.post( 'elasticc2/getobjphot', json = dic)\n",
+    "\n",
+    "    assert res.status_code == 200\n",
+    "    assert res.json()['status'] == \"ok\"\n",
+    "    reqs = res.json()\n",
+    "    return reqs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b6dcb4e8-ee8c-4ef4-ae6c-9d97ac6bd963",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def update_training_set(objectids, classes, day):\n",
+    "    #need to fetch the current features of the labeled objs (probably from the TOM, get the features and format correctly) \n",
+    "    #! figure out how to get a specific objid photometry\n",
+    "    #IN THE FUTURE get this from our mongodb; in the meantime we do not yet track which files contain which objects\n",
+    "\n",
+    "    # call something like elasticc2/getobjphot\n",
+    "    data_dic = get_object_phot('amandaw8', passwordfile = '/Users/arw/secrets/TOM2', obj_ids = objectids)\n",
+    "\n",
+    "    # put the ^ dictionary into the right format to get features\n",
+    "    data_dic=data_dic['diaobject'] \n",
+    "    \n",
+    "    # then do something like fit_TOM to get the features from the object photometry\n",
+    "    outdir = 'TOM_train_features_storage'\n",
+    "    os.makedirs(outdir, exist_ok=True)\n",
+    "    file_name = outdir+'/TOM_train_features_day_'+str(day)+'.csv'\n",
+    "    \n",
+    "    fit_TOM(data_dic, output_features_file = file_name, feature_extractor = 'Malanchev')\n",
+    "    data = pd.read_csv(file_name, index_col=False)\n",
+    "    data['orig_sample'] = 'train'\n",
+    "    \n",
+    "    data.to_csv(file_name,index=False)\n",
+    "\n",
+    "    # then read this file in and concatenate it with the current training set\n",
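+    "    # hedged sketch of that step, assuming the training features file\n",
+    "    # written earlier and an 'id' column as used by update_pool_stash:\n",
+    "    # train_path = 'TOM_days_storage/TOM_training_features'\n",
+    "    # current_train = pd.read_csv(train_path, index_col=False)\n",
+    "    # updated = pd.concat([current_train, data]).drop_duplicates('id', keep='last')\n",
+    "    # updated.to_csv(train_path, index=False)\n",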
+    "\n",
+    "    #REMOVE classified objects from the pool set each night. Just double check that \n",
+    "    #all SN in hot transients DO NOT have the same object ids as those in the training set\n",
+    "    #CHECK WITH ROB - can we make it so that gethotsne removes classified obj\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2932e2c0-7dc4-4344-bc51-1629487d4dde",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# add this classification and features to training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "398a6b2b-ddb7-48ee-ad4d-a6b238167c56",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "07d46352-1c42-4b44-a0e1-1cf0d478b3e5",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e0a6fc9a-e153-4e36-b2b8-d0cb486ec4a6",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "laiss_resspect",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/laiss_resspect_classifier/elasticc2_laiss_feature_extractor.py b/src/laiss_resspect_classifier/elasticc2_laiss_feature_extractor.py
index 6ae647f..9463ea4 100644
--- a/src/laiss_resspect_classifier/elasticc2_laiss_feature_extractor.py
+++ b/src/laiss_resspect_classifier/elasticc2_laiss_feature_extractor.py
@@ -39,7 +39,7 @@ class Elasticc2LaissFeatureExtractor(LaissFeatureExtractor):
     id_column = "objectid"
     label_column = "sntype"
-    non_anomaly_classes = ["Normal"]  # i.e. "Normal", "Ia", ...
+    non_anomaly_classes = ["Normal", "Ia"]  # i.e. "Normal", "Ia", ...
def __init__(self, **kwargs): super().__init__(**kwargs) @@ -48,7 +48,7 @@ def __init__(self, **kwargs): self.num_features = len(self.filters)*len(Elasticc2LaissFeatureExtractor.feature_names) + len(Elasticc2LaissFeatureExtractor.other_feature_names) + len(Elasticc2LaissFeatureExtractor.other_feature_names) @classmethod - def get_metadata_header(cls) -> list[str]: + def get_metadata_header(cls, **kwargs) -> list[str]: return [cls.id_column, "redshift", cls.label_column, "sncode", "sample"] @classmethod @@ -75,6 +75,8 @@ def fit_all(self) -> np.ndarray: laiss_features = ['None'] * self.num_features lightcurve = self.photometry + lightcurve['mag'] = -2.5*np.log10(lightcurve['flux']) + 27.5 + lightcurve['magerr'] = 2.5/np.log(10) * lightcurve['fluxerr']/lightcurve['flux'] min_obs_count = 4 _, property_names, _ = create_base_features_class(MAGN_EXTRACTOR, FLUX_EXTRACTOR) @@ -87,7 +89,7 @@ def fit_all(self) -> np.ndarray: print(f"Not enough obs for {self.id}. pass!\n") return - # extract lc features + # extract lc features t = np.array(detections['mjd']) m = np.array(detections['mag'], dtype=np.float64) merr = np.array(detections['magerr'], dtype=np.float64)