Numerical experiments I (training)

Notebook written by Matteo Sesia and Yaniv Romano

Stanford University, Department of Statistics

Last updated on: November 19, 2018

The purpose of this notebook is to make the numerical experiments described in the paper easy to reproduce. Running this code may take a few hours on a graphics processing unit (GPU).

Load the required libraries

In [1]:
import numpy as np
from DeepKnockoffs import KnockoffMachine
from DeepKnockoffs import GaussianKnockoffs
import data
import parameters

Data generating model

We model $X \in \mathbb{R}^p$ as having a multivariate Student's-t distribution, with $p=100$ and the covariance matrix of an autoregressive process of order one. The default correlation parameter for this distribution is $\rho = 0.5$ and the number of degrees of freedom is $\nu = 3$. A sketch of this sampling scheme follows the next cell.

In [2]:
# Number of features
p = 100

# Load the built-in multivariate Student's-t model and its default parameters
# The currently available built-in models are:
# - gaussian : Multivariate Gaussian distribution
# - gmm      : Gaussian mixture model
# - mstudent : Multivariate Student's-t distribution
# - sparse   : Multivariate sparse Gaussian distribution 
model = "mstudent"
distribution_params = parameters.GetDistributionParams(model, p)

# Initialize the data generator
DataSampler = data.DataSampler(distribution_params)
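
For intuition, a multivariate Student's-t vector with this covariance structure can be sampled by scaling a Gaussian draw with an inverse-chi-squared factor. The cell below is a self-contained illustration of this construction; it does not reproduce the internals of the `data` module.

In [ ]:
# Illustration only: sample from a multivariate Student's-t distribution
# with nu degrees of freedom and AR(1) scale matrix Sigma[i,j] = rho^|i-j|
rho, nu = 0.5, 3
indices = np.arange(p)
Sigma = rho ** np.abs(indices[:, None] - indices[None, :])

# A multivariate-t draw is a Gaussian draw divided by sqrt(chi2(nu)/nu)
Z = np.random.multivariate_normal(np.zeros(p), Sigma, size=n)
W = np.random.chisquare(nu, size=(n, 1))
X_illustration = Z / np.sqrt(W / nu)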

Let's sample $n=10000$ observations of $X$. This dataset will be used later to train a deep knockoff machine.

In [3]:
# Number of training examples
n = 10000

# Sample training data
X_train = DataSampler.sample(n)
print("Generated a training dataset of size: %d x %d." %(X_train.shape))
Generated a training dataset of size: 10000 x 100.

Second-order knockoffs

After computing the empirical covariance matrix of $X$ from the training data, we can initialize a generator of second-order knockoffs. The solution of the semidefinite program (SDP) determines the pairwise correlations between the original variables and the knockoffs produced by this algorithm. The sampling formula behind this construction is sketched after the next cell.

In [4]:
# Compute the empirical covariance matrix of the training data
SigmaHat = np.cov(X_train, rowvar=False)

# Initialize generator of second-order knockoffs
second_order = GaussianKnockoffs(SigmaHat, mu=np.mean(X_train,0), method="sdp")

# Measure pairwise second-order knockoff correlations 
corr_g = (np.diag(SigmaHat) - np.diag(second_order.Ds)) / np.diag(SigmaHat)

print('Average absolute pairwise correlation: %.3f.' %(np.mean(np.abs(corr_g))))
Average absolute pairwise correlation: 0.526.
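
For reference, the conditional Gaussian recipe behind second-order knockoffs can be written out explicitly: $\tilde{X} \mid X \sim \mathcal{N}\left(X - (X-\mu)\Sigma^{-1}D, \; 2D - D\Sigma^{-1}D\right)$, where $D = \mathrm{diag}(s)$ solves the SDP. Below is a minimal sketch of this formula using `SigmaHat` and `second_order.Ds`; in practice, the package's own sampling method should be used instead.

In [ ]:
# Minimal sketch of the Gaussian knockoff sampling formula (illustration
# only; `second_order.Ds` is assumed to be the diagonal matrix diag(s))
mu = np.mean(X_train, 0)
D = second_order.Ds
SigmaInvD = np.linalg.solve(SigmaHat, D)      # Sigma^{-1} D
mu_k = X_train - (X_train - mu) @ SigmaInvD   # conditional means
V_k = 2 * D - D @ SigmaInvD                   # conditional covariance
V_k = (V_k + V_k.T) / 2                       # symmetrize for stability
L = np.linalg.cholesky(V_k + 1e-10 * np.eye(p))
Xk_illustration = mu_k + np.random.randn(n, p) @ L.T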

Deep knockoff machine

The hyperparameters of the machine are set below to the default values most appropriate for the chosen built-in model. The figures in the paper were obtained by setting the number of epochs to 1000 and the learning rate to 0.001; to reduce the runtime, this notebook uses the values 100 and 0.01, respectively.

In [5]:
# Load the default hyperparameters for this model
training_params = parameters.GetTrainingHyperParams(model)

# Set the parameters for training deep knockoffs
pars = dict()
# Number of epochs
pars['epochs'] = 100
# Number of iterations over the full data per epoch
pars['epoch_length'] = 100
# Data type, either "continuous" or "binary"
pars['family'] = "continuous"
# Dimensions of the data
pars['p'] = p
# Size of the test set
pars['test_size']  = 0
# Batch size
pars['batch_size'] = int(0.5*n)
# Learning rate
pars['lr'] = 0.01
# Epochs after which to decrease the learning rate (no decay occurs when this equals the number of epochs)
pars['lr_milestones'] = [pars['epochs']]
# Width of the network (number of layers is fixed to 6)
pars['dim_h'] = int(10*p)
# Penalty for the MMD distance
pars['GAMMA'] = training_params['GAMMA']
# Penalty encouraging second-order knockoffs
pars['LAMBDA'] = training_params['LAMBDA']
# Decorrelation penalty hyperparameter
pars['DELTA'] = training_params['DELTA']
# Target pairwise correlations between variables and knockoffs
pars['target_corr'] = corr_g
# Kernel widths for the MMD measure (uniform weights)
pars['alphas'] = [1.,2.,4.,8.,16.,32.,64.,128.]
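
As described in the paper, the loss of the machine includes a maximum mean discrepancy (MMD) term that compares the joint sample of variables and knockoffs with copies in which some variables and knockoffs are swapped. As a rough illustration of the metric itself (not the package's exact implementation), a biased MMD estimate with a uniform mixture of Gaussian kernels over the bandwidths in `alphas` can be computed as follows:

In [ ]:
def mmd_mixture(Z1, Z2, alphas):
    # Biased MMD^2 estimate between samples Z1 and Z2, using a uniform
    # mixture of Gaussian kernels (illustration only; the exact kernel
    # parametrization inside the package may differ)
    def kernel(A, B):
        # Pairwise squared Euclidean distances between rows of A and B
        D2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
        return sum(np.exp(-D2 / (2 * a**2)) for a in alphas)
    return (kernel(Z1, Z1).mean() + kernel(Z2, Z2).mean()
            - 2 * kernel(Z1, Z2).mean())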

The machine will be stored in the tmp/ subdirectory for later use, with the checkpoint updated after each training epoch.

In [6]:
import os

# Create the output directory, if it does not already exist
os.makedirs("tmp", exist_ok=True)

# Where to store the machine
checkpoint_name = "tmp/" + model

# Where to print progress information
logs_name = "tmp/" + model + "_progress.txt"
In [7]:
# Initialize the machine
machine = KnockoffMachine(pars, checkpoint_name=checkpoint_name, logs_name=logs_name)

Let's fit the machine to the training data. The value of the loss function on the training data will be printed after each epoch, along with other diagnostics based on the MMD, the second moments, and the pairwise correlations between variables and knockoffs.

In [8]:
# Train the machine
print("Fitting the knockoff machine...")
machine.train(X_train)
Fitting the knockoff machine...
[   1/ 100], Loss: 0.1876, MMD: 0.1726, Cov: 1.289, Decorr: 0.299
[   2/ 100], Loss: 0.1454, MMD: 0.1366, Cov: 0.927, Decorr: 0.441
[   3/ 100], Loss: 0.1332, MMD: 0.1265, Cov: 0.786, Decorr: 0.502
[   4/ 100], Loss: 0.1294, MMD: 0.1236, Cov: 0.730, Decorr: 0.536
[   5/ 100], Loss: 0.1272, MMD: 0.1220, Cov: 0.702, Decorr: 0.555
[   6/ 100], Loss: 0.1260, MMD: 0.1212, Cov: 0.671, Decorr: 0.564
[   7/ 100], Loss: 0.1253, MMD: 0.1207, Cov: 0.694, Decorr: 0.567
[   8/ 100], Loss: 0.1247, MMD: 0.1204, Cov: 0.680, Decorr: 0.569
[   9/ 100], Loss: 0.1238, MMD: 0.1198, Cov: 0.668, Decorr: 0.569
[  10/ 100], Loss: 0.1233, MMD: 0.1195, Cov: 0.650, Decorr: 0.572
[  11/ 100], Loss: 0.1230, MMD: 0.1194, Cov: 0.651, Decorr: 0.570
[  12/ 100], Loss: 0.1225, MMD: 0.1191, Cov: 0.647, Decorr: 0.572
[  13/ 100], Loss: 0.1224, MMD: 0.1190, Cov: 0.623, Decorr: 0.570
[  14/ 100], Loss: 0.1221, MMD: 0.1189, Cov: 0.633, Decorr: 0.567
[  15/ 100], Loss: 0.1220, MMD: 0.1189, Cov: 0.631, Decorr: 0.567
[  16/ 100], Loss: 0.1221, MMD: 0.1190, Cov: 0.625, Decorr: 0.564
[  17/ 100], Loss: 0.1217, MMD: 0.1187, Cov: 0.612, Decorr: 0.564
[  18/ 100], Loss: 0.1217, MMD: 0.1188, Cov: 0.596, Decorr: 0.561
[  19/ 100], Loss: 0.1214, MMD: 0.1185, Cov: 0.602, Decorr: 0.560
[  20/ 100], Loss: 0.1214, MMD: 0.1186, Cov: 0.592, Decorr: 0.560
[  21/ 100], Loss: 0.1212, MMD: 0.1184, Cov: 0.567, Decorr: 0.559
[  22/ 100], Loss: 0.1213, MMD: 0.1185, Cov: 0.584, Decorr: 0.559
[  23/ 100], Loss: 0.1209, MMD: 0.1183, Cov: 0.572, Decorr: 0.558
[  24/ 100], Loss: 0.1214, MMD: 0.1187, Cov: 0.557, Decorr: 0.557
[  25/ 100], Loss: 0.1209, MMD: 0.1183, Cov: 0.555, Decorr: 0.554
[  26/ 100], Loss: 0.1211, MMD: 0.1185, Cov: 0.548, Decorr: 0.554
[  27/ 100], Loss: 0.1211, MMD: 0.1185, Cov: 0.547, Decorr: 0.552
[  28/ 100], Loss: 0.1209, MMD: 0.1184, Cov: 0.551, Decorr: 0.551
[  29/ 100], Loss: 0.1206, MMD: 0.1181, Cov: 0.541, Decorr: 0.551
[  30/ 100], Loss: 0.1208, MMD: 0.1183, Cov: 0.528, Decorr: 0.550
[  31/ 100], Loss: 0.1206, MMD: 0.1181, Cov: 0.536, Decorr: 0.546
[  32/ 100], Loss: 0.1206, MMD: 0.1181, Cov: 0.537, Decorr: 0.546
[  33/ 100], Loss: 0.1205, MMD: 0.1181, Cov: 0.521, Decorr: 0.547
[  34/ 100], Loss: 0.1203, MMD: 0.1179, Cov: 0.529, Decorr: 0.549
[  35/ 100], Loss: 0.1203, MMD: 0.1179, Cov: 0.528, Decorr: 0.544
[  36/ 100], Loss: 0.1205, MMD: 0.1181, Cov: 0.528, Decorr: 0.544
[  37/ 100], Loss: 0.1203, MMD: 0.1180, Cov: 0.525, Decorr: 0.543
[  38/ 100], Loss: 0.1203, MMD: 0.1179, Cov: 0.522, Decorr: 0.540
[  39/ 100], Loss: 0.1206, MMD: 0.1183, Cov: 0.509, Decorr: 0.541
[  40/ 100], Loss: 0.1204, MMD: 0.1180, Cov: 0.507, Decorr: 0.540
[  41/ 100], Loss: 0.1202, MMD: 0.1179, Cov: 0.510, Decorr: 0.538
[  42/ 100], Loss: 0.1204, MMD: 0.1181, Cov: 0.512, Decorr: 0.538
[  43/ 100], Loss: 0.1202, MMD: 0.1179, Cov: 0.511, Decorr: 0.536
[  44/ 100], Loss: 0.1204, MMD: 0.1181, Cov: 0.507, Decorr: 0.534
[  45/ 100], Loss: 0.1202, MMD: 0.1179, Cov: 0.501, Decorr: 0.534
[  46/ 100], Loss: 0.1200, MMD: 0.1177, Cov: 0.499, Decorr: 0.533
[  47/ 100], Loss: 0.1201, MMD: 0.1179, Cov: 0.500, Decorr: 0.533
[  48/ 100], Loss: 0.1201, MMD: 0.1178, Cov: 0.501, Decorr: 0.533
[  49/ 100], Loss: 0.1199, MMD: 0.1177, Cov: 0.489, Decorr: 0.534
[  50/ 100], Loss: 0.1201, MMD: 0.1179, Cov: 0.484, Decorr: 0.531
[  51/ 100], Loss: 0.1199, MMD: 0.1176, Cov: 0.492, Decorr: 0.530
[  52/ 100], Loss: 0.1201, MMD: 0.1179, Cov: 0.495, Decorr: 0.531
[  53/ 100], Loss: 0.1197, MMD: 0.1174, Cov: 0.491, Decorr: 0.530
[  54/ 100], Loss: 0.1200, MMD: 0.1178, Cov: 0.486, Decorr: 0.529
[  55/ 100], Loss: 0.1199, MMD: 0.1177, Cov: 0.484, Decorr: 0.528
[  56/ 100], Loss: 0.1200, MMD: 0.1178, Cov: 0.480, Decorr: 0.529
[  57/ 100], Loss: 0.1198, MMD: 0.1176, Cov: 0.480, Decorr: 0.526
[  58/ 100], Loss: 0.1200, MMD: 0.1178, Cov: 0.484, Decorr: 0.526
[  59/ 100], Loss: 0.1202, MMD: 0.1180, Cov: 0.480, Decorr: 0.526
[  60/ 100], Loss: 0.1200, MMD: 0.1178, Cov: 0.478, Decorr: 0.524
[  61/ 100], Loss: 0.1198, MMD: 0.1176, Cov: 0.487, Decorr: 0.522
[  62/ 100], Loss: 0.1200, MMD: 0.1178, Cov: 0.470, Decorr: 0.523
[  63/ 100], Loss: 0.1194, MMD: 0.1173, Cov: 0.486, Decorr: 0.521
[  64/ 100], Loss: 0.1199, MMD: 0.1177, Cov: 0.469, Decorr: 0.521
[  65/ 100], Loss: 0.1200, MMD: 0.1178, Cov: 0.471, Decorr: 0.519
[  66/ 100], Loss: 0.1197, MMD: 0.1176, Cov: 0.481, Decorr: 0.521
[  67/ 100], Loss: 0.1197, MMD: 0.1175, Cov: 0.462, Decorr: 0.520
[  68/ 100], Loss: 0.1198, MMD: 0.1176, Cov: 0.473, Decorr: 0.520
[  69/ 100], Loss: 0.1198, MMD: 0.1177, Cov: 0.473, Decorr: 0.519
[  70/ 100], Loss: 0.1199, MMD: 0.1178, Cov: 0.465, Decorr: 0.516
[  71/ 100], Loss: 0.1196, MMD: 0.1175, Cov: 0.467, Decorr: 0.518
[  72/ 100], Loss: 0.1199, MMD: 0.1177, Cov: 0.479, Decorr: 0.516
[  73/ 100], Loss: 0.1199, MMD: 0.1177, Cov: 0.472, Decorr: 0.516
[  74/ 100], Loss: 0.1200, MMD: 0.1178, Cov: 0.462, Decorr: 0.514
[  75/ 100], Loss: 0.1196, MMD: 0.1174, Cov: 0.461, Decorr: 0.517
[  76/ 100], Loss: 0.1200, MMD: 0.1179, Cov: 0.463, Decorr: 0.514
[  77/ 100], Loss: 0.1196, MMD: 0.1175, Cov: 0.464, Decorr: 0.513
[  78/ 100], Loss: 0.1194, MMD: 0.1173, Cov: 0.474, Decorr: 0.513
[  79/ 100], Loss: 0.1199, MMD: 0.1177, Cov: 0.463, Decorr: 0.513
[  80/ 100], Loss: 0.1195, MMD: 0.1174, Cov: 0.462, Decorr: 0.512
[  81/ 100], Loss: 0.1194, MMD: 0.1173, Cov: 0.456, Decorr: 0.514
[  82/ 100], Loss: 0.1200, MMD: 0.1179, Cov: 0.460, Decorr: 0.508
[  83/ 100], Loss: 0.1196, MMD: 0.1175, Cov: 0.470, Decorr: 0.508
[  84/ 100], Loss: 0.1197, MMD: 0.1176, Cov: 0.456, Decorr: 0.510
[  85/ 100], Loss: 0.1197, MMD: 0.1177, Cov: 0.461, Decorr: 0.509
[  86/ 100], Loss: 0.1197, MMD: 0.1176, Cov: 0.465, Decorr: 0.506
[  87/ 100], Loss: 0.1193, MMD: 0.1172, Cov: 0.458, Decorr: 0.509
[  88/ 100], Loss: 0.1196, MMD: 0.1175, Cov: 0.461, Decorr: 0.507
[  89/ 100], Loss: 0.1197, MMD: 0.1176, Cov: 0.459, Decorr: 0.506
[  90/ 100], Loss: 0.1195, MMD: 0.1174, Cov: 0.453, Decorr: 0.507
[  91/ 100], Loss: 0.1195, MMD: 0.1174, Cov: 0.457, Decorr: 0.507
[  92/ 100], Loss: 0.1199, MMD: 0.1178, Cov: 0.450, Decorr: 0.505
[  93/ 100], Loss: 0.1192, MMD: 0.1172, Cov: 0.448, Decorr: 0.506
[  94/ 100], Loss: 0.1197, MMD: 0.1176, Cov: 0.448, Decorr: 0.503
[  95/ 100], Loss: 0.1192, MMD: 0.1172, Cov: 0.448, Decorr: 0.504
[  96/ 100], Loss: 0.1195, MMD: 0.1174, Cov: 0.448, Decorr: 0.503
[  97/ 100], Loss: 0.1194, MMD: 0.1174, Cov: 0.447, Decorr: 0.503
[  98/ 100], Loss: 0.1192, MMD: 0.1172, Cov: 0.456, Decorr: 0.504
[  99/ 100], Loss: 0.1195, MMD: 0.1174, Cov: 0.458, Decorr: 0.502
[ 100/ 100], Loss: 0.1195, MMD: 0.1174, Cov: 0.441, Decorr: 0.503
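
After training, the machine can generate knockoff copies for new data drawn from the same distribution. A usage sketch, assuming that `KnockoffMachine` exposes a `generate` method taking an $n \times p$ array:

In [ ]:
# Sample a fresh dataset and generate its knockoff copy
# (`generate` is assumed to take an (n, p) array and return one)
X_test = DataSampler.sample(n)
Xk_test = machine.generate(X_test)
print("Generated knockoffs of size: %d x %d." % (Xk_test.shape))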