From 9f4f0d20cd5526d77f598e19a1332eaa9a7a976f Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Mon, 6 Sep 2021 10:05:09 -0400 Subject: [PATCH] #27 enabled shap plots --- clin_msi/msi_model_scripts/apply_msi_model.py | 16 ++++----- clin_msi/msi_training.py | 35 ------------------- clin_msi/predict.py | 8 ++--- 3 files changed, 12 insertions(+), 47 deletions(-) delete mode 100644 clin_msi/msi_training.py diff --git a/clin_msi/msi_model_scripts/apply_msi_model.py b/clin_msi/msi_model_scripts/apply_msi_model.py index 6bbf162..521e46a 100644 --- a/clin_msi/msi_model_scripts/apply_msi_model.py +++ b/clin_msi/msi_model_scripts/apply_msi_model.py @@ -130,11 +130,11 @@ def apply_model(infile,moddir,outfile,shap_plot_dir=None): df=df.rename(columns={sampcol:'samp'}) dfnew,shapdict,curfeats=apply_mod_to_dataframe(df,moddir) dfnew[['samp','yprob']].to_csv(outfile,index=False) - # ## GRAB SHAP DATA - # testdat,bmdict=grab_shap_data(dfnew,shapdict,curfeats) - # ## EXPORT SHAP PLOTS - # for i in range(len(testdat)): - # shap_outfile=shap_plot_dir+"/shap_plot_"+str(i+1)+'.png' - # shaprec=testdat.iloc[i] - # shapdict_loc=bmdict[i] - # build_n_save_shap_plot(shaprec,shapdict_loc,shap_outfile) + ## GRAB SHAP DATA + testdat,bmdict=grab_shap_data(dfnew,shapdict,curfeats) + ## EXPORT SHAP PLOTS + for i in range(len(testdat)): + shap_outfile=shap_plot_dir+"/shap_plot_"+str(i+1)+'.png' + shaprec=testdat.iloc[i] + shapdict_loc=bmdict[i] + build_n_save_shap_plot(shaprec,shapdict_loc,shap_outfile) diff --git a/clin_msi/msi_training.py b/clin_msi/msi_training.py deleted file mode 100644 index a9c66da..0000000 --- a/clin_msi/msi_training.py +++ /dev/null @@ -1,35 +0,0 @@ -import pickle -import logging - -import xgboost as xgb - -## FUNCTIONS -## INTERNAL TRAINING FUNCTION -def modfit(df,curfeats): - X=df[curfeats] - Y=df['y'] - model = xgb.XGBClassifier(max_depth=1, - colsample_bytree=0.001, - n_estimators=500, - eval_metric='error', - use_label_encoder=False ## to prevent a warning - ) - model.fit(X,Y) - return model - -## MAIN FUNCTION -def train_models(df,moddir): - curfeats=[x for x in df if x != 'y'] - moddict={} - nrun=500 - modfiles=[moddir+'/xgb_'+str(j+1)+'.pkl' for j in range(nrun)] - for j in range(nrun): - if (j % 10 == 9): - ## DINKY PROGRESS REPORT - logging.info("Training model "+str(j+1)) - traininds_loc=df.sample(frac=0.8).index - mymod=modfit(df.loc[traininds_loc],curfeats) - pkl_connect = open(modfiles[j], 'wb') - pickle.dump(mymod, pkl_connect) - pkl_connect.close() - return(modfiles) diff --git a/clin_msi/predict.py b/clin_msi/predict.py index 02849d9..362eccb 100644 --- a/clin_msi/predict.py +++ b/clin_msi/predict.py @@ -6,8 +6,8 @@ import pysam import pandas as pd -from .count_normalization.normalize_counts import parse_raw_data -from .msi_model_scripts.apply_msi_model import apply_model +from count_normalization.normalize_counts import parse_raw_data +from msi_model_scripts.apply_msi_model import apply_model def repeat_finder(s): @@ -96,7 +96,7 @@ def predict( #apply model to normalized msi counts final_results_file = os.path.join(output_dir, sample_name + '_MSIscore.txt') - apply_model(os.path.join(output_dir, sample_name + '_normalized.csv'), model_dir, final_results_file) + apply_model(os.path.join(output_dir, sample_name + '_normalized.csv'), model_dir, final_results_file,shap_plot_dir=output_dir) if __name__ == '__main__': - predict() \ No newline at end of file + predict()