Conditional inference II (LASSO)

web.stanford.edu/class/stats364/

Jonathan Taylor

Spring 2020

Conditional inference II

LASSO

library(glmnet)
## Loading required package: Matrix
## Loading required package: foreach
## Loaded glmnet 2.0-18
set.seed(2)
n = 100
p = 20
X = matrix(rnorm(n * p), n, p)
Y = rnorm(n)
L = glmnet(X, Y, intercept=FALSE, standardize=FALSE)
# ignore intercept
beta_hat = coef(L, s=1.5/sqrt(n), exact=TRUE, x=X, y=Y)[-1] 


LASSO with glmnet


selected = which(beta_hat != 0)
list(selected=selected)
## $selected
## [1]  2 20
confint(lm(Y ~ X[,selected] - 1))
##                     2.5 %      97.5 %
## X[, selected]1 -0.4059135  0.03010663
## X[, selected]2 -0.4949535 -0.07162624


simulate = function(n=100,
                    p=20,
                    s=1.5) {
    X = matrix(rnorm(n * p), n, p)
    Y = rnorm(n)
    L = glmnet(X, Y, intercept=FALSE, standardize=FALSE)
    # ignore intercept
    beta_hat = coef(L, s=s/sqrt(n), exact=TRUE, x=X, y=Y)[-1] 
    selected = beta_hat != 0
    if (sum(selected) > 0) { 
        intervals = confint(lm(Y ~ X[,selected] - 1))
        covered = (intervals[,1] < 0) * (intervals[,2] > 0)
        return(covered)
    } else {
        return(numeric(0))
    }
}
simulate()
## X[, selected]1 X[, selected]2 
##              0              0


A little simulation

covered = c()
for (i in 1:100) {
    covered = c(covered, simulate())
}
mean(covered)
## [1] 0.592


KKT conditions


Subgradient equations
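
In standard form: for
\[
\hat{\beta} = \mathop{\mathrm{argmin}}_{\beta} \; \tfrac{1}{2}\|Y - X\beta\|_2^2 + \lambda \|\beta\|_1,
\]
stationarity requires
\[
X^T(Y - X\hat{\beta}) = \lambda u, \qquad u \in \partial \|\hat{\beta}\|_1,
\]
i.e. \(u_j = \text{sign}(\hat{\beta}_j)\) if \(\hat{\beta}_j \neq 0\) and \(u_j \in [-1, 1]\) otherwise. (Recall glmnet divides the squared-error term by \(n\), which is why the code above passes \(s = 1.5/\sqrt{n}\) while this \(\lambda\) is \(1.5\sqrt{n}\).)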


Subgradient equations in block form
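
Writing \(E\) for the active set and \(s_E = \text{sign}(\hat{\beta}_E)\), the same equations split into an active block and an inactive block:
\[
\begin{aligned}
X_E^T\big(Y - X_E \hat{\beta}_E\big) &= \lambda s_E, \\
\big\|X_{-E}^T\big(Y - X_E \hat{\beta}_E\big)\big\|_\infty &\leq \lambda.
\end{aligned}
\]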


Selection event

Conditioning on signs

signs = sign(beta_hat[selected])
list(selected=selected,
     signs=signs)
## $selected
## [1]  2 20
## 
## $signs
## [1] -1 -1


Inference requires a model!

Saturated model (PoSI, Lee et al. (2016))

Conditional model


Conditional inference for \(H_{0,\delta}:\eta^T\mu=\delta\)?


Workhorse


Affine representation for \(Y\)


Polyhedral lemma
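
A standard statement (following Lee et al. (2016), in the notation of the general form later in these slides): suppose \(Y \sim N(\mu, \Sigma)\), the selection event is \(\{AY \leq b\}\), and we target \(\eta^T\mu\). Set \(c = \Sigma\eta / \eta^T\Sigma\eta\) and \(N = (I - c\eta^T)Y\). Then
\[
\{AY \leq b\} = \{{\cal V}^{-}(N) \leq \eta^T Y \leq {\cal V}^{+}(N)\}
\]
with
\[
\begin{aligned}
{\cal V}^{-}(N) &= \max_{j: (Ac)_j < 0} \frac{(b - AN)_j}{(Ac)_j}, \\
{\cal V}^{+}(N) &= \min_{j: (Ac)_j > 0} \frac{(b - AN)_j}{(Ac)_j},
\end{aligned}
\]
(rows with \((Ac)_j = 0\) constrain only \(N\)), so that, conditional on \(N\) and the event, \(\eta^T Y\) follows a \(N(\eta^T\mu, \eta^T\Sigma\eta)\) distribution truncated to \([{\cal V}^{-}, {\cal V}^{+}]\).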


Polyhedral lemma (corollary)


Marginalizing over signs
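
Conditioning only on \(\hat{E} = E\) makes the selection event a union of the sign-specific polyhedra, so the reference distribution becomes a Gaussian truncated to a union of intervals, one per realizable sign vector \(s\):
\[
\bigcup_{s} \big[{\cal V}^{-}(N; s),\; {\cal V}^{+}(N; s)\big].
\]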


Not all sign vectors are realizable


How to compute the intersection (polyhedral lemma)


Results!

import numpy as np
X = np.load('X.npy')
Y = np.load('Y.npy')
from selectinf.algorithms.lasso import lasso
n, p = X.shape
L = lasso.gaussian(X, Y, 1.5 * np.sqrt(X.shape[0]), sigma=1.)
L.fit()
## array([ 0.        , -0.01205666, -0.        , -0.        , -0.        ,
##         0.        , -0.        , -0.        ,  0.        , -0.        ,
##         0.        ,  0.        ,  0.        , -0.        ,  0.        ,
##         0.        ,  0.        ,  0.        , -0.        , -0.11651073])
S = L.summary(compute_intervals=True,
              alternative='onesided',
              level=0.90)
S
##     variable     pval     lasso   onestep  ...  upper_confidence  lower_trunc  upper_trunc        sd
## 1          1  0.77475 -0.012057 -0.187903  ...          2.432066         -inf    -0.175847  0.102646
## 19        19  0.06657 -0.116511 -0.283290  ...          0.030983         -inf    -0.182397  0.099658
## 
## [2 rows x 9 columns]


One-sided \(p\)-value computation

from scipy.stats import norm as normal_dbn
Z = S['onestep'] / S['sd']
UZ = S['upper_trunc'] / S['sd']
LZ = S['lower_trunc'] / S['sd']
# here both observed signs are -1
pval = ((normal_dbn.cdf(Z) - normal_dbn.cdf(LZ)) / 
        (normal_dbn.cdf(UZ) - normal_dbn.cdf(LZ)))
pval
## array([0.77475002, 0.06656959])
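
The ratio above is the CDF, evaluated at \(Z\), of a standard normal truncated to \([LZ, UZ]\). As a reusable helper (an illustrative sketch; `truncnorm_cdf` is our name, not part of `selectinf`):

```python
import numpy as np
from scipy.stats import norm

def truncnorm_cdf(z, lower, upper):
    # CDF at z of N(0, 1) truncated to [lower, upper]; under the
    # null this pivot is Uniform(0, 1) on the selection event.
    num = norm.cdf(z) - norm.cdf(lower)
    den = norm.cdf(upper) - norm.cdf(lower)
    return num / den

# sanity check: with no truncation it reduces to the usual normal CDF
print(truncnorm_cdf(0.0, -np.inf, np.inf))  # 0.5
```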

Two-sided \(p\)-value

twosided = np.asarray(L.summary(alternative='twosided')['pval'])
twosided, 2 * np.minimum(pval, 1 - pval)
## (array([0.45049997, 0.13313918]), array([0.45049997, 0.13313918]))


Using R

library(selectiveInference)
## Loading required package: intervals
## 
## Attaching package: 'intervals'
## The following object is masked from 'package:Matrix':
## 
##     expand
## Loading required package: survival
## Loading required package: adaptMCMC
## Loading required package: parallel
## Loading required package: coda
## Loading required package: MASS
selectiveInference::fixedLassoInf(X, 
                                  Y,
                                  beta_hat,
                                  1.5 * sqrt(n),
                                  intercept=FALSE,
                                  alpha=0.1,
                                  sigma=1)
## 
## Call:
## selectiveInference::fixedLassoInf(x = X, y = Y, beta = beta_hat, 
##     lambda = 1.5 * sqrt(n), intercept = FALSE, sigma = 1, alpha = 0.1)
## 
## Standard deviation of noise (specified or estimated) sigma = 1.000
## 
## Testing results at lambda = 15.000, with alpha = 0.100
## 
##  Var   Coef Z-score P-value LowConfPt UpConfPt LowTailArea UpTailArea
##    2 -0.188  -1.831   0.775    -0.250    2.433       0.049       0.05
##   20 -0.283  -2.843   0.067    -0.444    0.032       0.049       0.05
## 
## Note: coefficients shown are partial regression coefficients


Conditional inference for LASSO


Choice of model


Nuisance parameters

Exercise

Verify the equality in law above, as well as the conclusion.


YACM: yet another choice of model

Exercise

Verify that using \({\cal M}_E\) and corresponding conditional model \({\cal M}_E^*\) yields the same reference distribution as starting with \({\cal M}\) or \({\cal M}_f\).


Choice of target parameter \(\theta_{\cal T}\)


Pairs model


Linear decomposition


Linear decomposition


General form of polyhedral lemma

For event \(\{D:AD \leq b\}\).

  1. Set \[ \Gamma = \Gamma(F) = \text{Cov}_F(D, \hat{\theta}_{\cal T}) \text{Cov}_F(\hat{\theta}_{\cal T})^{-1} \]

  2. Compute \({\cal N} = D - \Gamma(F) \hat{\theta}_{\cal T}\) \[ \begin{aligned} {\cal V}^L({\cal N}, \Gamma) &= \max_{j: (A\Gamma)_j < 0} \frac{(b - A{\cal N})_j}{(A \Gamma)_j} \\ {\cal V}^U({\cal N}, \Gamma) &= \min_{j: (A\Gamma)_j > 0} \frac{(b - A{\cal N})_j}{(A \Gamma)_j} \\ \end{aligned} \]
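
These two steps can be sketched numerically for a scalar target (an illustrative implementation with our own variable names, not the library's):

```python
import numpy as np

def truncation_limits(A, b, N, Gamma):
    # Event {D : A D <= b}, with D = N + Gamma * theta_hat for a
    # scalar target theta_hat. Each constraint row j bounds theta_hat
    # from one side, depending on the sign of (A Gamma)_j.
    AG = A @ Gamma
    resid = b - A @ N
    lower, upper = -np.inf, np.inf
    neg, pos = AG < 0, AG > 0
    if neg.any():
        lower = np.max(resid[neg] / AG[neg])   # V^L
    if pos.any():
        upper = np.min(resid[pos] / AG[pos])   # V^U
    return lower, upper

# toy event: D = (0.5 + theta, 0) with D[0] <= 2, -D[0] <= 1, D[1] <= 3
A = np.array([[1., 0.], [-1., 0.], [0., 1.]])
b = np.array([2., 1., 3.])
N = np.array([0.5, 0.])
Gamma = np.array([1., 0.])
lo, hi = truncation_limits(A, b, N, Gamma)  # theta must lie in [-1.5, 1.5]
```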


A more careful analysis of pairs model (fixed \(p\))




A subtle point


Exercise

By “run the polyhedral lemma” we mean compute \({\cal V}^{L/U}(D - \Gamma(\hat{\Sigma}) \hat{\theta}_{\cal T}, \Gamma(\hat{\Sigma}))\). Prove the claim above, i.e. with the observed value of \(D\) as above, show that \(\hat{\theta}_{\cal T}\) always lies in the interval \([{\cal V}^{L}, {\cal V}^{U}]\). Supposing that after rescaling \(\text{Var}_F(\hat{\theta}_{\cal T}) = O(1)\), how much do you expect the endpoints \({\cal V}^{L/U}\) to change if you used the unobservable \(D=(\hat{\theta}_{\cal T}, G^A_{(E,s_E)}, G^I_{(E,s_E)})\) instead of \((\hat{\theta}_{\cal T}, n^{1/2}\hat{\beta}_E, n^{-1/2} \nabla \ell(\hat{\beta})[-E])\) under a no rare selection assumption? How large are these changes compared to the typical sizes of \({\cal V}^{L/U}\)? (Remember that under no rare selection the “sizes” of random variables are the same under \(F\) as under \(F^*\).)


Logistic LASSO


Restricted problem


Exercise

Suppose \(U\) has decided, after seeing that the logistic LASSO has selected \((E, s_E)\), that the proper target of inference is the \(k\)-th coordinate of the probit parameter of the true data-generating distribution \(F\) in the pairs model with features \(E'\). Describe how to construct an (asymptotically) valid confidence interval for this parameter that “runs”, i.e. does not give nonsensical results such as \(\hat{\theta}_{\cal T}\) falling outside its truncation interval.


Selected model framework Taylor and Tibshirani (2018)


Logistic inference under null (so \(\beta=0\))

def simulate(n=500, p=50, C=1):
    
    X = np.random.standard_normal((n, p))
    Y = np.random.binomial(1, 0.5, size=(n,))
    W = np.ones(p) * C * np.sqrt(n)
    L = lasso.logistic(X, Y, W)
    beta_hat = L.fit()
    if (beta_hat != 0).sum() > 0:
        S = L.summary(compute_intervals=True,
                      alternative='twosided',
                      level=0.90)

        return S

simulate()
##     variable      pval     lasso   onestep  ...  upper_confidence  lower_trunc  upper_trunc        sd
## 17        17  0.016169  0.143245  0.327934  ...          0.496885     0.184689     0.533234  0.091908
## 18        18  0.348696  0.007160  0.176163  ...          0.197229     0.169004     3.912629  0.087659
## 26        26  0.759404 -0.032987 -0.209377  ...          0.487633         -inf    -0.176390  0.087252
## 
## [3 rows x 9 columns]
P, C = [], []
for _ in range(100):
    S = simulate()
    if S is not None:
        P.extend(list(S['pval']))
        C.extend(list((S['lower_confidence'] < 0) & (S['upper_confidence'] > 0)))
print('coverage: (target 90%): ', np.mean(C))
## coverage: (target 90%):  0.9234234234234234


Pivot (p-value) plot


Exercise

Verify the above claimed asymptotic independence of \((\hat{\theta}_{\cal T}, G^A_{(E,s_E)})\) and \(G^I_{(E,s_E)}\) under \({\cal M}_{P,E}\) and \({\cal M}_{P,E}^*\) under an appropriate no rare selection assumption for \((E,s_E)\).


Exercise

For either the OLS or Logistic LASSO, suppose we want to marginalize over signs \(s_E\). Describe an algorithm using the polyhedral lemma for each \(s_E\) that will “run” (i.e. not give nonsensical answers due to \(\hat{\theta}_{\cal T}\) being outside its truncation set) and give (asymptotically) valid inference conditional only on the event \(\hat{E}=E\).


Unpenalized variables, feature weights


Xi = np.hstack([np.ones((n, 1)),
                X])
weights = 1.5 * np.sqrt(n) * np.ones(p+1)
weights[0] = 0
Li = lasso.gaussian(Xi, Y, weights, sigma=1.)
Li.fit()
## array([ 0.08047905,  0.        , -0.01307299, -0.        , -0.        ,
##        -0.        ,  0.        , -0.        , -0.        ,  0.00087677,
##        -0.        ,  0.        ,  0.        ,  0.        , -0.        ,
##         0.        ,  0.        ,  0.        ,  0.        , -0.        ,
##        -0.1046863 ])
Li.summary(compute_intervals=True,
           alternative='onesided',
           level=0.90)
##     variable      pval     lasso   onestep  ...  upper_confidence  lower_trunc  upper_trunc        sd
## 0          0  0.488701  0.080479  0.061638  ...          0.218739    -0.013930     0.773783  0.101140
## 2          2  0.624802 -0.013073 -0.184972  ...          2.233520    -0.217428    -0.171899  0.102695
## 9          9  0.984540  0.000877  0.127013  ...         -0.427391     0.126137     0.574571  0.098759
## 20        20  0.018953 -0.104686 -0.260735  ...         -0.152579    -0.269230    -0.156049  0.101312
## 
## [4 rows x 9 columns]


def simulate_intercept(n=500, p=50, C=1):
    
    X = np.random.standard_normal((n, p+1))
    X[:,0] = 1
    Y = np.random.binomial(1, 0.6, size=(n,))
    W = np.ones(p+1) * C * np.sqrt(n)
    W[0] = 0
    L = lasso.logistic(X, Y, W)
    beta_hat = L.fit()
    if (beta_hat != 0).sum() > 1:
        S = L.summary(compute_intervals=True,
                      alternative='twosided',
                      level=0.90)

        return S.iloc[1:] # remove intercept row

simulate_intercept()
##     variable      pval     lasso   onestep  ...  upper_confidence  lower_trunc  upper_trunc        sd
## 18        18  0.895433 -0.029592 -0.225419  ...          0.688152    -0.579943    -0.195827  0.094736
## 21        21  0.197391 -0.078194 -0.246453  ...          0.076704    -0.368645    -0.168259  0.089775
## 36        36  0.356272  0.007813  0.179142  ...          0.205369     0.171329     1.432253  0.091516
## 
## [3 rows x 9 columns]
P, C = [], []
for _ in range(100):
    S = simulate_intercept()
    if S is not None:
        P.extend(list(S['pval']))
        C.extend(list((S['lower_confidence'] < 0) & (S['upper_confidence'] > 0)))


print('coverage: (target 90%): ', np.mean(C))
## coverage: (target 90%):  0.9282051282051282


Pivot (p-value) plot


Full model inference for LASSO Liu, Markovic and Tibshirani (2018)

Up to now, we’ve allowed the target to be based on \((\hat{E},\hat{s}_E)\) for both the OLS and logistic LASSO, but this is not strictly necessary if:

  1. We are willing to assume the full parametric model is correct e.g. for OLS and \(n \gg p\) \[ {\cal L}(Y|X) = N(X\beta^*, \sigma^2 I) \in {\cal M}_f \]

  2. We restrict attention to targets \(\beta_j = \beta_{j|\{1,\dots,p\}}\).

Inference is based on conditioning on the outcome of queries \(Q_j(X,Y) = 1_{\hat{E}(X,Y)}(j)\).

Comments


Selection event for full model inference


Conditional inference II

Exercise

Suppose \(X\) is fixed. Verify that if \(D=X^TY\) and \(\theta_{\cal T}=\beta_j\) that \({\cal N} \equiv X_{-j}^TY\) by showing that our usual definition of \({\cal N}\) is in 1:1 correspondence with \(X_{-j}^TY\). More formally, in terms of conditioning, argue that the sigma algebra generated by \({\cal N}\) is the same as that generated by \(X_{-j}^TY\).


Selection event for full model inference

from selectinf.algorithms.lasso import ROSI

def simulate_full(n=500, p=50, C=1.5):
    
    X = np.random.standard_normal((n, p))
    Y = np.random.standard_normal(n)
    W = np.ones(p) * C * np.sqrt(n)
    L = ROSI.gaussian(X, Y, W)
    beta_hat = L.fit()
    if (beta_hat != 0).sum() > 0:
        S = L.summary(compute_intervals=True,
                      level=0.90)

        return S

simulate_full()
##     variable      pval     lasso  ...  upper_confidence  lower_truncation  upper_truncation
## 6          6  0.187981 -0.040119  ...          0.012707         -0.057826          0.085826
## 17        17  0.261607 -0.027967  ...          0.015356         -0.069161          0.080316
## 25        25  0.478552  0.010119  ...          0.131542         -0.064073          0.071951
## 29        29  0.998388  0.016198  ...          0.124809         -0.109881          0.044545
## 35        35  0.756812 -0.001374  ...          0.039444         -0.072520          0.087578
## 39        39  0.781752 -0.003502  ...          0.040540         -0.069454          0.089908
## 45        45  0.810212 -0.016688  ...          0.038102         -0.049755          0.086653
## 
## [7 rows x 9 columns]
P, C, L = [], [], []
for _ in range(100):
    S = simulate_full()
    if S is not None:
        L.extend((S['upper_confidence'] - S['lower_confidence']) /
                 (2 * normal_dbn.ppf(0.95) * S['sd']))
        P.extend(list(S['pval']))
        C.extend(list((S['lower_confidence'] < 0) & (S['upper_confidence'] > 0)))
print('coverage: (target 90%): ', np.mean(C))
## coverage: (target 90%):  0.898021308980213


Pivot (p-value) plot


Length of intervals compared to naive


Why are they sometimes shorter?


Selection event for full model inference Markovic (2019)



Logistic full model inference Markovic (2019)

from selectinf.algorithms.lasso import ROSI

def simulate_full_logistic(n=500, p=50, C=1):
    
    X = np.random.standard_normal((n, p+1)); X[:,0] = 1
    W = C * np.ones(p+1) * np.sqrt(n)
    W[0] = 0
    Y = np.random.binomial(1, 0.6, size=(n,))
    L = ROSI.logistic(X, Y, W)
    beta_hat = L.fit()
    if (beta_hat != 0).sum() > 1:
        S = L.summary(compute_intervals=True,
                      level=0.90)

        return S.iloc[1:] # remove intercept

simulate_full_logistic()
P, C = [], []
for _ in range(100):
    S = simulate_full_logistic()
    if S is not None:
        P.extend(list(S['pval']))
        C.extend(list((S['lower_confidence'] < 0) & (S['upper_confidence'] > 0)))
print('coverage: (target 90%): ', np.mean(C))
## coverage: (target 90%):  0.892018779342723


Pivot (p-value) plot


Exercise

  1. Explain how you can be sure this algorithm will “run” (i.e. not produce nonsensical results).

  2. What will the confidence intervals look “more” like – Zhong and Prentice’s ones based on two-sided truncation or the ones in the file drawer example based on one-sided truncation?

  3. Do you expect this algorithm to produce valid conditional inference with \(p\) fixed and \(\lambda = n^{1/2} C\)?


Exercise

Returning to high dimensional \(p > n\) or \(p \gg n\) setting, let’s revisit our algorithms for the LASSO conditioning on \((\hat{E}, \hat{s}_E)\).

  1. Supposing one can obtain some reasonable estimate of \(\sigma^2\), what is the rough computational cost of computing a \(p\)-value for testing \(H_0:\beta_{j|E}=\delta\) in \({\cal M}\)? In \({\cal M}_f\)?

  2. What is the rough computational cost in model \({\cal M}_{NP}\)?

  3. Supposing we make sparsity assumptions so that, with high probability, \(|\hat{E}| \ll n < p\), do these algorithms require estimating the inverse of poorly conditioned matrices (under reasonable non-degeneracy assumptions)?

  4. In summary, do these algorithms “run” in the \(p > n\) setting (under such sparsity and non-degeneracy assumptions)? How about \(p \gg n\)?

  5. (Open-ended) Do you expect them to “work”, i.e. provide valid conditional inference (under such sparsity and non-degeneracy assumptions)?


Takeaways