/*
 * Solve a distributed lasso problem, i.e.,
 *
 *   minimize (1/2)||Ax - b||_2^2 + lambda*||x||_1.
 *
 * The implementation uses MPI for distributed communication
 * and the GNU Scientific Library (GSL) for math.
 */
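
/*
 * The solver uses consensus ADMM (Boyd et al., "Distributed Optimization
 * and Statistical Learning via the Alternating Direction Method of
 * Multipliers"). Each of the N processes holds a slice (A_i, b_i) of the
 * data and repeats, with S_k denoting the soft-thresholding operator
 * defined at the bottom of this file:
 *
 *   x_i := argmin_x (1/2)||A_i x - b_i||_2^2 + (rho/2)||x - z + u_i||_2^2
 *   z   := S_{lambda/(N*rho)}( avg_i(x_i + u_i) )
 *   u_i := u_i + x_i - z
 */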

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "mmio.h"
#include <mpi.h>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_blas.h>
#include <gsl/gsl_linalg.h>

void soft_threshold(gsl_vector *v, double k);
double objective(gsl_matrix *A, gsl_vector *b, double lambda, gsl_vector *z);

int main(int argc, char **argv) {
    const int MAX_ITER  = 50;
    const double RELTOL = 1e-2;
    const double ABSTOL = 1e-4;

    /*
     * Some bookkeeping variables for MPI. The 'rank' of a process is its
     * numeric id in the process pool: if we run a program via
     * `mpirun -np 4 foo', the process ranks are 0 through 3. 'size' is the
     * total number of processes running (in this example, 4), and N below
     * is the same count stored as a double so that later divisions are
     * floating-point.
     */

    int rank;
    int size;

    MPI_Init(&argc, &argv);               // Initialize the MPI execution environment
    MPI_Comm_rank(MPI_COMM_WORLD, &rank); // Determine current running process
    MPI_Comm_size(MPI_COMM_WORLD, &size); // Total number of processes
    double N = (double) size;             // Number of subsystems (one per process) for ADMM

    /* Read in local data */

    int skinny; // A flag indicating whether the matrix A is fat or skinny
    FILE *f;
    int m, n;
    int row, col;
    double entry;

    /*
     * The process with rank i reads files called A(i+1).dat and b(i+1).dat
     * from the data/ directory; these are its local data and do not need
     * to be visible to any other processes. Note that m and n here refer
     * to the dimensions of the *local* coefficient matrix.
     */
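
    /*
     * For illustration, a hypothetical 2x2 data/A1.dat in Matrix Market
     * "array" format would look like
     *
     *   %%MatrixMarket matrix array real general
     *   2 2
     *   1.0
     *   3.0
     *   2.0
     *   4.0
     *
     * i.e. a header, the dimensions, and then the entries of [1 2; 3 4]
     * in column-major order, which is why the read loop below maps the
     * flat index i to row i % m and column i / m.
     */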

    /* Read A */
    char s[20];
    sprintf(s, "data/A%d.dat", rank + 1);
    printf("[%d] reading %s\n", rank, s);

    f = fopen(s, "r");
    if (f == NULL) {
        printf("[%d] ERROR: %s does not exist, exiting.\n", rank, s);
        exit(EXIT_FAILURE);
    }
    mm_read_mtx_array_size(f, &m, &n);
    gsl_matrix *A = gsl_matrix_calloc(m, n);
    for (int i = 0; i < m*n; i++) {
        row = i % m; // entries are listed in column-major order,
        col = i / m; // so integer division recovers the column
        fscanf(f, "%lf", &entry);
        gsl_matrix_set(A, row, col, entry);
    }
    fclose(f);

    /* Read b */
    sprintf(s, "data/b%d.dat", rank + 1);
    printf("[%d] reading %s\n", rank, s);

    f = fopen(s, "r");
    if (f == NULL) {
        printf("[%d] ERROR: %s does not exist, exiting.\n", rank, s);
        exit(EXIT_FAILURE);
    }
    mm_read_mtx_array_size(f, &m, &n);
    gsl_vector *b = gsl_vector_calloc(m);
    for (int i = 0; i < m; i++) {
        fscanf(f, "%lf", &entry);
        gsl_vector_set(b, i, entry);
    }
    fclose(f);

    /* Restore m and n to the dimensions of A (reading b clobbered them) */
    m = A->size1;
    n = A->size2;
    skinny = (m >= n);
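
    /*
     * The skinny flag decides which Gram matrix gets factored below: the
     * n-by-n A^T A + rho*I when A is skinny, or the smaller m-by-m
     * I + (1/rho)*A*A^T (via the matrix inversion lemma) when A is fat.
     */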

    /*
     * These are all variables related to ADMM itself. There are many
     * more variables than in the Matlab implementation because we also
     * require vectors and matrices to store various intermediate results.
     * The naming scheme follows the Matlab version of this solver.
     */

    double rho = 1.0; // ADMM penalty parameter

    gsl_vector *x     = gsl_vector_calloc(n);
    gsl_vector *u     = gsl_vector_calloc(n);
    gsl_vector *z     = gsl_vector_calloc(n);
    gsl_vector *y     = gsl_vector_calloc(n);
    gsl_vector *r     = gsl_vector_calloc(n);
    gsl_vector *zprev = gsl_vector_calloc(n);
    gsl_vector *zdiff = gsl_vector_calloc(n);

    gsl_vector *q  = gsl_vector_calloc(n);
    gsl_vector *w  = gsl_vector_calloc(n);
    gsl_vector *Aq = gsl_vector_calloc(m);
    gsl_vector *p  = gsl_vector_calloc(m);

    gsl_vector *Atb = gsl_vector_calloc(n);

    double send[3]; // an array used to aggregate 3 scalars at once
    double recv[3]; // used to receive the results of these aggregations

    double nxstack  = 0;
    double nystack  = 0;
    double prires   = 0;
    double dualres  = 0;
    double eps_pri  = 0;
    double eps_dual = 0;

    /* Precompute and cache factorizations */

    gsl_blas_dgemv(CblasTrans, 1, A, b, 0, Atb); // Atb = A^T b

    /*
     * The lasso regularization parameter here is just hardcoded
     * to 0.5 for simplicity. Using the lambda_max heuristic would require
     * network communication, since it requires looking at the *global* A^T b.
     */

    double lambda = 0.5;
    if (rank == 0) {
        printf("using lambda: %.4f\n", lambda);
    }
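
    /*
     * For reference, a sketch of the lambda_max heuristic mentioned above
     * (not used here; global_Atb is an illustrative name): each process
     * would sum its local A^T b into the global A^T b with an Allreduce,
     * and lambda_max = ||A^T b||_inf is the smallest lambda for which
     * z = 0 is optimal, so lambda is typically set to a fraction of it:
     *
     *   gsl_vector *global_Atb = gsl_vector_calloc(n);
     *   MPI_Allreduce(Atb->data, global_Atb->data, n,
     *                 MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
     *   double lambda_max = fabs(gsl_vector_get(global_Atb,
     *                                           gsl_blas_idamax(global_Atb)));
     */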

    gsl_matrix *L;

    /* Use the matrix inversion lemma for efficiency; see section 4.2 of the paper */
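
    /*
     * Concretely, for fat A the lemma (Sherman-Morrison-Woodbury) gives
     *
     *   (A^T A + rho I)^{-1} = (1/rho) I - (1/rho^2) A^T (I + (1/rho) A A^T)^{-1} A,
     *
     * so only an m-by-m factorization is needed, which is much cheaper
     * than the n-by-n one when m << n. The x-update below applies this
     * identity term by term.
     */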
    if (skinny) {
        /* L = chol(AtA + rho*I) */
        L = gsl_matrix_calloc(n, n);

        gsl_matrix *AtA = gsl_matrix_calloc(n, n);
        gsl_blas_dsyrk(CblasLower, CblasTrans, 1, A, 0, AtA);

        gsl_matrix *rhoI = gsl_matrix_calloc(n, n);
        gsl_matrix_set_identity(rhoI);
        gsl_matrix_scale(rhoI, rho);

        gsl_matrix_memcpy(L, AtA);
        gsl_matrix_add(L, rhoI);
        gsl_linalg_cholesky_decomp(L);

        gsl_matrix_free(AtA);
        gsl_matrix_free(rhoI);
    } else {
        /* L = chol(I + 1/rho*AAt) */
        L = gsl_matrix_calloc(m, m);

        gsl_matrix *AAt = gsl_matrix_calloc(m, m);
        gsl_blas_dsyrk(CblasLower, CblasNoTrans, 1, A, 0, AAt);
        gsl_matrix_scale(AAt, 1/rho);

        gsl_matrix *eye = gsl_matrix_calloc(m, m);
        gsl_matrix_set_identity(eye);

        gsl_matrix_memcpy(L, AAt);
        gsl_matrix_add(L, eye);
        gsl_linalg_cholesky_decomp(L);

        gsl_matrix_free(AAt);
        gsl_matrix_free(eye);
    }

    /* Main ADMM solver loop */

    int iter = 0;
    if (rank == 0) {
        printf("%3s %10s %10s %10s %10s %10s\n", "#", "r norm", "eps_pri", "s norm", "eps_dual", "objective");
    }

    while (iter < MAX_ITER) {

        /* u-update: u = u + x - z */
        gsl_vector_sub(x, z);
        gsl_vector_add(u, x);
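
        /*
         * The u-update runs at the top of the loop with the previous
         * iterates, which is equivalent to the usual x, z, u ordering;
         * on the first pass x, z, and u are all zero, so it is a no-op.
         * Note that gsl_vector_sub overwrites x with x - z, but x is
         * recomputed in the x-update immediately below.
         */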

        /* x-update: x = (A^T A + rho I) \ (A^T b + rho*(z - u)) */
        gsl_vector_memcpy(q, z);
        gsl_vector_sub(q, u);
        gsl_vector_scale(q, rho);
        gsl_vector_add(q, Atb); // q = A^T b + rho*(z - u)

        if (skinny) {
            /* x = U \ (L \ q) */
            gsl_linalg_cholesky_solve(L, q, x);
        } else {
            /* x = q/rho - 1/rho^2 * A^T * (U \ (L \ (A*q))) */
            gsl_blas_dgemv(CblasNoTrans, 1, A, q, 0, Aq);
            gsl_linalg_cholesky_solve(L, Aq, p);
            gsl_blas_dgemv(CblasTrans, 1, A, p, 0, x); /* now x = A^T * (U \ (L \ (A*q))) */
            gsl_vector_scale(x, -1/(rho*rho));
            gsl_vector_scale(q, 1/rho);
            gsl_vector_add(x, q);
        }

        /*
         * Message-passing: compute the global sum over all processors of
         * the contents of w and send. Also, update z.
         */

        gsl_vector_memcpy(w, x);
        gsl_vector_add(w, u); // w = x + u

        gsl_blas_ddot(r, r, &send[0]);
        gsl_blas_ddot(x, x, &send[1]);
        gsl_blas_ddot(u, u, &send[2]);
        send[2] *= rho * rho; /* u is the scaled dual (y_i = rho*u_i), so ||y_i||^2 = rho^2 * ||u_i||^2 */
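
        /*
         * Note that r is computed at the bottom of the loop, so send[0]
         * carries ||r_i||^2 from the previous iteration (it is 0 on the
         * first pass); the termination check below therefore uses a
         * primal residual that lags the current iterates by one update.
         */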

        gsl_vector_memcpy(zprev, z);

        // could be reduced to a single Allreduce call by concatenating send to w
        MPI_Allreduce(w->data, z->data, n, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
        MPI_Allreduce(send, recv, 3, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);

        prires  = sqrt(recv[0]); /* sqrt(sum ||r_i||_2^2) */
        nxstack = sqrt(recv[1]); /* sqrt(sum ||x_i||_2^2) */
        nystack = sqrt(recv[2]); /* sqrt(sum ||y_i||_2^2) */

        gsl_vector_scale(z, 1/N);
        soft_threshold(z, lambda/(N*rho));
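
        /*
         * After the Allreduce, z holds sum_i (x_i + u_i); scaling by 1/N
         * makes it the average, and soft-thresholding with parameter
         * lambda/(N*rho) is the proximal step for the l1 term, giving the
         * consensus z-update z = S_{lambda/(N*rho)}(avg(x) + avg(u)).
         */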

        /* Termination checks */

        /* dual residual */
        gsl_vector_memcpy(zdiff, z);
        gsl_vector_sub(zdiff, zprev);
        dualres = sqrt(N) * rho * gsl_blas_dnrm2(zdiff); /* ||s^k||_2 = sqrt(N) * rho * ||z - zprev||_2 */

        /* compute primal and dual feasibility tolerances */
        eps_pri  = sqrt(n*N)*ABSTOL + RELTOL * fmax(nxstack, sqrt(N)*gsl_blas_dnrm2(z));
        eps_dual = sqrt(n*N)*ABSTOL + RELTOL * nystack;
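
        /*
         * These are the standard ADMM stopping tolerances (Boyd et al.,
         * section 3.3) applied to the stacked iterates: with p = n*N primal
         * variables the absolute term is sqrt(n*N)*ABSTOL, and the relative
         * term compares against max(||x||, sqrt(N)*||z||) for the primal
         * residual and against ||y|| for the dual residual.
         */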

        if (rank == 0) {
            printf("%3d %10.4f %10.4f %10.4f %10.4f %10.4f\n", iter,
                   prires, eps_pri, dualres, eps_dual, objective(A, b, lambda, z));
        }

        if (prires <= eps_pri && dualres <= eps_dual) {
            break;
        }

        /* Compute residual: r = x - z */
        gsl_vector_memcpy(r, x);
        gsl_vector_sub(r, z);

        iter++;
    }

    /* Have the master write out the results to disk */
    if (rank == 0) {
        f = fopen("data/solution.dat", "w");
        gsl_vector_fprintf(f, z, "%lf");
        fclose(f);
    }

    MPI_Finalize(); /* Shut down the MPI execution environment */

    /* Free memory */
    gsl_matrix_free(A);
    gsl_matrix_free(L);
    gsl_vector_free(b);
    gsl_vector_free(x);
    gsl_vector_free(u);
    gsl_vector_free(z);
    gsl_vector_free(y);
    gsl_vector_free(r);
    gsl_vector_free(w);
    gsl_vector_free(zprev);
    gsl_vector_free(zdiff);
    gsl_vector_free(q);
    gsl_vector_free(Aq);
    gsl_vector_free(Atb);
    gsl_vector_free(p);

    return EXIT_SUCCESS;
}

double objective(gsl_matrix *A, gsl_vector *b, double lambda, gsl_vector *z) {
    double obj = 0;
    gsl_vector *Azb = gsl_vector_calloc(A->size1);
    gsl_blas_dgemv(CblasNoTrans, 1, A, z, 0, Azb); // Azb = A*z
    gsl_vector_sub(Azb, b);                        // Azb = A*z - b
    double Azb_nrm2;
    gsl_blas_ddot(Azb, Azb, &Azb_nrm2);            // Azb_nrm2 = ||A*z - b||_2^2
    obj = 0.5 * Azb_nrm2 + lambda * gsl_blas_dasum(z);
    gsl_vector_free(Azb);
    return obj;
}
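
/*
 * Note that each process calls objective with its *local* A and b, so the
 * value printed by rank 0 above is (1/2)||A_1 z - b_1||_2^2 + lambda*||z||_1,
 * a cheap local surrogate rather than the global objective; the global value
 * would require one more reduction over the least-squares terms.
 */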

void soft_threshold(gsl_vector *v, double k) {
    double vi;
    for (size_t i = 0; i < v->size; i++) {
        vi = gsl_vector_get(v, i);
        if (vi > k)       { gsl_vector_set(v, i, vi - k); }
        else if (vi < -k) { gsl_vector_set(v, i, vi + k); }
        else              { gsl_vector_set(v, i, 0); }
    }
}
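
/*
 * Example: soft_threshold with k = 0.5 maps the vector (1.2, -0.3, -2.0)
 * to (0.7, 0.0, -1.5), shrinking each entry toward zero by k and zeroing
 * those with |v_i| <= k.
 */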