Treatment Ignorant Effect Estimation

PyTorch implementation of "Causal Estimation for Text Data with Apparent Overlap Violations".

Environment Setup

pip install -r requirements.txt

Then install the stable release (1.11.0) of PyTorch following the official installation guide.
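
For example, a plain install from PyPI could look like the command below; the exact command depends on your platform and CUDA version, so use the selector on pytorch.org if you need a specific build.

pip install torch==1.11.0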

Description

Qmod

  • Q-Net:
    • Inputs of the system include
      • Text data $X$
      • A binary treatment variable $A$
      • A continuous outcome variable $Y$
      • (Optional) A binary confound $C$
    • Outputs
      • Estimated conditional outcomes $Q_0=E(Y|A=0,X)$, $Q_1=E(Y|A=1,X)$
    • Other related functions
      • k_fold_fit_and_predict: train Q-Net in K-fold fashion
      • get_agg_q: train Q-Net with different seeds and get aggregated (average) conditional outcomes
  • Propensity estimation
    • Inputs
      • A binary treatment variable
      • Conditional outcomes: $Q_0=E(Y|A=0,X)$, $Q_1=E(Y|A=1,X)$
      • Choice of the nonparametric model
    • Output
      • Estimated propensity scores $g$
  • TI estimator
    • Inputs
      • A binary treatment variable $A$
      • A continuous outcome variable $Y$
      • Conditional outcomes: $Q_0=E(Y|A=0,X)$, $Q_1=E(Y|A=1,X)$
      • Estimated propensity scores $g$
      • Error bound for the confidence interval
    • Output
      • The TI estimator together with its uncertainty quantification and confidence interval (see the sketch after this list)
    • Other related functions:
      • get_estimands: get a list of causal estimators
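
The propensity-estimation and TI-estimator steps can be summarized with a minimal NumPy/scikit-learn sketch. This is only an illustration of the idea (an AIPTW-style estimate with propensities fit nonparametrically on the estimated conditional outcomes, and a Wald confidence interval); the function ti_sketch below is hypothetical, and in practice the repository's own get_propensities and get_TI_estimator should be used.

import numpy as np
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor

def ti_sketch(A, Y, Q0, Q1, error=0.05):
    """Illustrative AIPTW-style estimate; not the repository's exact implementation."""
    A, Y, Q0, Q1 = map(np.asarray, (A, Y, Q0, Q1))
    # Approximate the propensity score g nonparametrically from the conditional outcomes (Q0, Q1)
    gpr = GaussianProcessRegressor().fit(np.column_stack([Q0, Q1]), A)
    g = np.clip(gpr.predict(np.column_stack([Q0, Q1])), 1e-3, 1 - 1e-3)
    # Influence-curve values of the AIPTW estimator of E[Y(1)] - E[Y(0)]
    phi = Q1 - Q0 + A * (Y - Q1) / g - (1 - A) * (Y - Q0) / (1 - g)
    est = phi.mean()
    se = phi.std(ddof=1) / np.sqrt(len(phi))
    z = norm.ppf(1 - error / 2)  # e.g. 1.96 for error = 0.05
    return est, (est - z * se, est + z * se)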

simulation

  • Get simulated data:
    • run_simulation
      • data: raw data with treatments, confounds and texts
      • propensities: prechosen propensity scores $\pi(C)=P(A=1|C),\ C=0,1$
      • beta_t: the treatment level; if the outcome is continuous, then this is also the oracle NDE
      • beta_c: the confounding level
      • gamma: the level of noise
      • cts: outcome type (True: continuous outcome; False: binary outcome)
      • return: simulated data frame; offset $E\left(\pi(C)\right)$
import pandas as pd
from simulation import run_simulation  # assuming run_simulation lives in the simulation module

raw_df = pd.read_csv('./src/music.csv')
simulated_df, offset = run_simulation(raw_df, propensities=[0.8, 0.6],
                                      beta_t=1.0,   # 1.0, 0.0
                                      beta_c=50.0,  # 50.0, 100.0
                                      gamma=1.0,    # 1.0, 4.0
                                      cts=True)
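
The simulated frame can then be split into the training and test sets used in the Usage section below, for example:

from sklearn.model_selection import train_test_split

df, test_df = train_test_split(simulated_df, test_size=0.2, random_state=0)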

Usage

Import functions.

from Qmod import *

Initialize the Q-Net model wrapper and train the model.

mod = QNet(batch_size = 64, # batch size for training
           a_weight = 0.1,  # loss weight for A ~ text
           y_weight = 0.1,  # loss weight for Y ~ A + text
           mlm_weight=1.0,  # loss weight for the DistilBERT masked-language-modeling objective
           modeldir='model/train') # directory for saving the best model
           
mod.train(df['text'],  # texts in training data
          df['T'],     # treatments in training data
          df['C'],     # confounds in training data, binary
          df['Y'],     # outcomes in training data
          epochs=20,   # the maximum number of training epochs
          learning_rate = 2e-5)  # learning rate for the training

Then, obtain the conditional outcomes for the test data.

Q0, Q1, A, Y, _ = mod.get_Q(test_df['text'], test_df['T'], test_df['C'], test_df['Y'])
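
As a quick sanity check (not part of the repository's pipeline), the naive plug-in estimate can be computed directly from the conditional outcomes and later compared with the TI estimator:

import numpy as np

plugin_est = np.mean(np.asarray(Q1) - np.asarray(Q0))
print(f'Plug-in estimate: {plugin_est:.4f}')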

Compute the estimated propensity scores.

g = get_propensities(A, Q0, Q1, 
                     model_type='GaussianProcessRegression', # choose the nonparametric model
                     kernel=None,    # kernel function for GPR
                     random_state=0) # random seed for GPR 
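
Since the method targets settings with apparent overlap violations, it can also be informative to inspect the estimated propensity scores by treatment arm; a possible matplotlib sketch (not part of the repository):

import numpy as np
import matplotlib.pyplot as plt

g, A = np.asarray(g), np.asarray(A)
plt.hist(g[A == 1], bins=30, alpha=0.5, label='A = 1')
plt.hist(g[A == 0], bins=30, alpha=0.5, label='A = 0')
plt.xlabel('estimated propensity score g')
plt.legend()
plt.show()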

Get the TI estimator and its confidence interval.

get_TI_estimator(g, Q0, Q1, A, Y, 
                  error=0.05)  # error bound for confidence interval
