Data carving

web.stanford.edu/class/stats364/

Jonathan Taylor

Spring 2020

Fithian, Sun, T. (2014)

Data carving

Fithian, Sun, T. (2014)

Data carving

Fithian, Sun, T. (2014)

Data carving

Fithian, Sun, T. (2014)

Data carving with LASSO

Fithian, Sun, T. (2014)

Data carving with LASSO: reparametrization

Tian et al. (2016)

Reparametrization

import numpy as np, pandas as pd, seaborn as sns
from selectinf.randomized import lasso
from selectinf.tests.instance import gaussian_instance
from selectinf.learning.utils import (split_partial_model_inference,
                                      split_full_model_inference)
## /Users/jonathantaylor/git-repos/selectinf/selectinf/learning/keras_fit.py:50: UserWarning: module `keras` not importable, `keras_fit` and `keras_fit_multilabel` will not be importable
##   warnings.warn('module `keras` not importable, `keras_fit` and `keras_fit_multilabel` will not be importable')
import matplotlib.pyplot as plt
import statsmodels.api as sm
ECDF = sm.distributions.ECDF
# %matplotlib inline
# Compare data carving with data splitting over repeated simulations.
# `simulate` and `pvalue_plot` are defined elsewhere (not in this excerpt).
# Each `simulate()` call returns a DataFrame with one row per reported
# variable, holding carved ('lower'/'upper', 'cover', 'pvalue') and split
# ('split_lower'/'split_upper', 'split_coverage', 'split_pvalue') inference
# side by side.  Lines starting with `##` are output recorded from the
# original run (the simulation is random, so reruns will differ).
S = simulate()
## /Users/jonathantaylor/opt/anaconda3/lib/python3.7/site-packages/pandas/core/computation/expressions.py:194: UserWarning: evaluating in Python space because the '*' operator is not supported by numexpr for the bool dtype, use '&' instead
##   op=op_str, alt_op=unsupported[op_str]
# Print explicitly: a bare `S` is only displayed by a notebook/knitr echo and
# does nothing when this file is run as a plain script.
print(S)
##          MLE  true target  cover        SE  ...  split_lower  split_upper    target  variable
## 0  -1.276152     0.333526   True  1.228676  ...    -3.388519     1.326546  0.214730         2
## 1  -0.720044    -0.139625   True  1.128729  ...    -2.411866     2.266531 -0.056667         4
## 2   0.284744     0.225595   True  1.301743  ...    -1.166275     3.973852  0.333505         8
## 3   3.840087     3.043339   True  1.114823  ...     1.022116     5.892335  2.918780         9
## 4   2.105436     3.162512   True  1.241932  ...     1.267002     5.862226  3.002715        11
## 5   4.100433     2.928182   True  1.043039  ...     2.042312     6.637502  3.178533        15
## 6   3.932669     2.997871   True  1.092015  ...    -0.769532     4.300592  2.841684        17
## 7   0.875019     0.018633   True  1.266655  ...    -0.591187     4.607814  0.073290        18
## 8   0.546038     0.053838   True  1.096639  ...    -3.405330     0.975307  0.110363        21
## 9   3.210836     3.059558   True  1.155215  ...     2.161504     6.943532  3.051207        27
## 10  4.129639     2.777719   True  1.057458  ...     1.888165     6.986781  2.785616        28
## 11  2.370937     3.011855   True  1.134471  ...     0.178898     4.659832  2.869338        32
## 12  3.885217     2.792739   True  1.021503  ...     2.360765     6.885176  2.766129        34
## 13  3.410596     3.004038   True  1.049087  ...     1.370267     5.968115  3.134744        38
## 14  0.111909     0.230520   True  1.189844  ...    -4.015998     0.624091  0.155777        40
## 15  0.864049    -0.065623   True  1.062754  ...    -2.826887     1.802275 -0.118068        42
## 16  1.241298     0.625545   True  1.225506  ...     0.413445     5.157180  0.478685        45
## 
## [17 rows x 18 columns]

# Number of independent simulated instances to pool.
n_sim = 50
results = pd.concat([simulate() for _ in range(n_sim)])

# Interval lengths under data splitting vs. data carving.
lengths = pd.DataFrame({'split': results['split_upper'] - results['split_lower'],
                        'carve': results['upper'] - results['lower']})


def _density_plot(values, ax=None, label=None):
    """Draw a density-scaled histogram with a KDE overlay and return the Axes.

    ``sns.distplot`` was deprecated in seaborn 0.11 and later removed;
    ``sns.histplot(..., kde=True, stat='density')`` is its replacement.
    Fall back to ``distplot`` only on older seaborn that lacks ``histplot``.
    """
    if hasattr(sns, 'histplot'):
        return sns.histplot(values, kde=True, stat='density', ax=ax, label=label)
    return sns.distplot(values, ax=ax, label=label)


ax = _density_plot(lengths['split'], label='Split')
_density_plot(lengths['carve'], ax=ax, label='Carve')
ax.legend()


def _coverage(df):
    """Return (carved, split) empirical coverage for the rows of ``df``.

    Averages the 'cover' and 'split_coverage' columns, which presumably hold
    per-row booleans saying whether each interval covered the target.
    """
    return np.mean(df['cover']), np.mean(df['split_coverage'])


def _rejection_rate(df, level=0.05):
    """Return the (carved, split) fraction of p-values below ``level``."""
    return np.mean(df['pvalue'] < level), np.mean(df['split_pvalue'] < level)


print(*_coverage(results))
## 0.9259818731117825 0.8897280966767371

# Variables whose true target is exactly zero.  Presumably the p-values test
# target == 0, so rejections here are false rejections (type I error).
results0 = results.loc[lambda df: df['true target'] == 0]
print(*_coverage(results0))
## 0.9090909090909091 0.8181818181818182
print(*_rejection_rate(results0))
## 0.0 0.0
pvalue_plot(results0)

# Variables with a nonzero true target: the rejection rate reflects power.
resultsA = results.loc[lambda df: df['true target'] != 0]
print(*_coverage(resultsA))
## 0.9262672811059908 0.890937019969278
print(*_rejection_rate(resultsA))
## 0.45161290322580644 0.34715821812596004
pvalue_plot(resultsA)