scalapack_connector.h Source File

LibRPA: scalapack_connector.h Source File
LibRPA
scalapack_connector.h
1 #ifndef SCALAPACK_CONNECTOR_H
2 #define SCALAPACK_CONNECTOR_H
3 
4 #include <complex>
5 #include "interface/blacs_scalapack.h"
6 #include "lapack_connector.h"
7 
9 {
10 public:
11  // indexing functions adapted from ScaLAPACK TOOLS
// Process coordinate owning global index indxglob when blocks of nb entries are dealt
// round-robin over nprocs processes, starting at process isrcproc.
// Evaluates (isrcproc + indxglob / nb) % nprocs with integer division; iproc is not used.
// NOTE(review): no "-1" shift is applied, so indices appear to be 0-based - confirm with callers.
inline static int indxg2p(const int &indxglob, const int &nb, const int &iproc, const int &isrcproc, const int &nprocs)
{
    const int block_id = indxglob / nb;
    return (isrcproc + block_id) % nprocs;
}
// Local index corresponding to global index indxglob for block size nb over nprocs processes.
// Evaluates nb * (indxglob / (nb * nprocs)) + indxglob % nb with integer division;
// iproc and isrcproc do not enter the computation.
// NOTE(review): indices appear to be 0-based - confirm with callers.
inline static int indxg2l(const int &indxglob, const int &nb, const int &iproc, const int &isrcproc, const int &nprocs)
{
    const int full_cycles = indxglob / (nb * nprocs);
    const int offset_in_block = indxglob % nb;
    return nb * full_cycles + offset_in_block;
}
// Global index corresponding to local index indxloc held by process iproc, for block size nb
// over nprocs processes with the first block on process isrcproc.
// Evaluates nprocs * nb * (indxloc / nb) + indxloc % nb + ((nprocs + iproc - isrcproc) % nprocs) * nb.
// NOTE(review): indices appear to be 0-based - confirm with callers.
inline static int indxl2g(const int &indxloc, const int &nb, const int &iproc, const int &isrcproc, const int &nprocs)
{
    const int local_block = indxloc / nb;
    const int offset_in_block = indxloc % nb;
    // distance of this process from the source process, counted cyclically
    const int proc_distance = (nprocs + iproc - isrcproc) % nprocs;
    return nprocs * nb * local_block + offset_in_block + proc_distance * nb;
}
25 
// Fills desc_T from desc with entries 2<->3 and 4<->5 exchanged; entries 0, 1, 6, 7 and 8
// are copied to the same position. Entries are written in ascending index order.
// NOTE(review): in a ScaLAPACK array descriptor slots 2/3 are M/N and 4/5 are MB/NB,
// so this presumably describes the transposed matrix - confirm against the descriptor layout.
static void transpose_desc( int desc_T[9], const int desc[9] )
{
    // source_slot[k] is the index in desc that feeds desc_T[k]
    const int source_slot[9] = {0, 1, 3, 2, 5, 4, 6, 7, 8};
    for (int slot = 0; slot < 9; ++slot)
    {
        desc_T[slot] = desc[source_slot[slot]];
    }
}
35 
36  static void blacs_gridinit( int &ictxt, const char order, const int nprow, const int npcol )
37  {
38  blacs_gridinit_(&ictxt, &order, &nprow, &npcol);
39  }
40 
41  static void blacs_gridinfo( const int &ictxt, int &nprow, int &npcol, int &myprow, int &mypcol )
42  {
43  blacs_gridinfo_( &ictxt, &nprow, &npcol, &myprow, &mypcol );
44  }
45 
46  static int numroc( const int n, const int nb, const int iproc, const int srcproc, const int nprocs )
47  {
48  return numroc_(&n, &nb, &iproc, &srcproc, &nprocs);
49  }
50 
51  static void descinit(
52  int *desc,
53  const int m, const int n, const int mb, const int nb, const int irsrc, const int icsrc,
54  const int ictxt, const int lld, int &info )
55  {
56  descinit_(desc, &m, &n, &mb, &nb, &irsrc, &icsrc, &ictxt, &lld, &info);
57 // descinit_(desc, &n, &m, &nb, &mb, &irsrc, &icsrc, &ictxt, &lld, &info);
58  }
59 
60  // C = a * A.? * B.? + b * C
61  static void pgemm(
62  const char transa, const char transb,
63  const int M, const int N, const int K,
64  const double alpha,
65  const double *A, const int IA, const int JA, const int *DESCA,
66  const double *B, const int IB, const int JB, const int *DESCB,
67  const double beta,
68  double *C, const int IC, const int JC, const int *DESCC)
69  {
70  int DESCA_T[9], DESCB_T[9], DESCC_T[9];
71  transpose_desc( DESCA_T, DESCA );
72  transpose_desc( DESCB_T, DESCB );
73  transpose_desc( DESCC_T, DESCC );
74  pdgemm_(
75  &transb, &transa,
76  &N, &M, &K,
77  &alpha,
78  B, &JB, &IB, DESCB_T,
79  A, &JA, &IA, DESCA_T,
80  &beta,
81  C, &JC, &IC, DESCC_T);
82  // pdgemm_(
83  // &transa, &transb,
84  // &M, &N, &K,
85  // &alpha,
86  // A, &JA, &IA, DESCA,
87  // B, &JB, &IB, DESCB,
88  // &beta,
89  // C, &JC, &IC, DESCC);
90  }
91 
92  static inline
93  void pdgetrf(int m, int n, matrix &a,int ia, int ja, int *desca, int *ipiv, int *info)
94  {
95  double *aux = LapackConnector::transpose_matrix(a, n, m);
96  pdgetrf_( &m, &n, aux, &ia, &ja, desca, ipiv, info);
97  LapackConnector::transpose_matrix(aux, a, n, m);
98  delete[] aux;
99  return;
100  }
101 
102  static inline
103  void pscal_f(const int &N, const float &alpha, float *X,
104  const int &IX, const int &JX, const int *DESCX,
105  const int &INCX)
106  {
107  psscal_(&N, &alpha, X, &IX, &JX, DESCX, &INCX);
108  }
109 
110  static inline
111  void pscal_f(const int &N, const double &alpha, double *X,
112  const int &IX, const int &JX, const int *DESCX,
113  const int &INCX)
114  {
115  pdscal_(&N, &alpha, X, &IX, &JX, DESCX, &INCX);
116  }
117 
118  static inline
119  void pscal_f(const int &N, const std::complex<float> &alpha, std::complex<float> *X,
120  const int &IX, const int &JX, const int *DESCX,
121  const int &INCX)
122  {
123  pcscal_(&N, &alpha, X, &IX, &JX, DESCX, &INCX);
124  }
125 
126  static inline
127  void pscal_f(const int &N, const std::complex<double> &alpha, std::complex<double> *X,
128  const int &IX, const int &JX, const int *DESCX,
129  int &INCX)
130  {
131  pzscal_(&N, &alpha, X, &IX, &JX, DESCX, &INCX);
132  }
133 
134  static inline
135  void pscal_f(const int &N, const float &alpha, std::complex<float> *X,
136  const int &IX, const int &JX, const int *DESCX,
137  const int &INCX)
138  {
139  pcsscal_(&N, &alpha, X, &IX, &JX, DESCX, &INCX);
140  }
141 
142  static inline
143  void pscal_f(const int &N, const double &alpha, std::complex<double> *X,
144  const int &IX, const int &JX, const int *DESCX,
145  const int &INCX)
146  {
147  pzdscal_(&N, &alpha, X, &IX, &JX, DESCX, &INCX);
148  }
149 
150  static inline
151  float pdot_f(const int &N, const float *X, const int &IX,
152  const int &JX, const int *DESCX, const int &INCX,
153  const float *Y, const int &IY, const int &JY,
154  const int *DESCY, const int &INCY)
155  {
156  float result = 0.0;
157  psdot_(&N, &result, X, &IX, &JX, DESCX, &INCX, Y, &IY, &JY, DESCY, &INCY);
158  return result;
159  }
160 
161  static inline
162  double pdot_f(const int &N, const double *X, const int &IX,
163  const int &JX, const int *DESCX, const int &INCX,
164  const double *Y, const int &IY, const int &JY,
165  const int *DESCY, const int &INCY)
166  {
167  double result = 0.0;
168  pddot_(&N, &result, X, &IX, &JX, DESCX, &INCX, Y, &IY, &JY, DESCY, &INCY);
169  return result;
170  }
171 
172  static inline
173  std::complex<float> pdot_f(const int &N, const std::complex<float> *X,
174  const int &IX, const int &JX, const int *DESCX,
175  const int &INCX, const std::complex<float> *Y,
176  const int &IY, const int &JY, const int *DESCY,
177  const int &INCY)
178  {
179  std::complex<float> result = 0.0;
180  pcdotu_(&N, &result, X, &IX, &JX, DESCX, &INCX, Y, &IY, &JY, DESCY, &INCY);
181  return result;
182  }
183 
184  static inline
185  std::complex<float> pdotc_f(const int &N, const std::complex<float> *X,
186  const int &IX, const int &JX, const int *DESCX,
187  const int &INCX, const std::complex<float> *Y,
188  const int &IY, const int &JY, const int *DESCY,
189  const int &INCY)
190  {
191  std::complex<float> result = 0.0;
192  pcdotc_(&N, &result, X, &IX, &JX, DESCX, &INCX, Y, &IY, &JY, DESCY, &INCY);
193  return result;
194  }
195 
196  static inline
197  std::complex<double> pdot_f(const int &N, const std::complex<double> *X,
198  const int &IX, const int &JX, const int *DESCX,
199  const int &INCX, const std::complex<double> *Y,
200  const int &IY, const int &JY, const int *DESCY,
201  const int &INCY)
202  {
203  std::complex<double> result = 0.0;
204  pzdotu_(&N, &result, X, &IX, &JX, DESCX, &INCX, Y, &IY, &JY, DESCY, &INCY);
205  return result;
206  }
207 
208  static inline
209  std::complex<double> pdotc_f(const int &N, const std::complex<double> *X,
210  const int &IX, const int &JX, const int *DESCX,
211  const int &INCX, const std::complex<double> *Y,
212  const int &IY, const int &JY, const int *DESCY,
213  const int &INCY)
214  {
215  std::complex<double> result = 0.0;
216  pzdotc_(&N, &result, X, &IX, &JX, DESCX, &INCX, Y, &IY, &JY, DESCY, &INCY);
217  return result;
218  }
219 
220  static inline
221  void pgemv_f(const char &transa, const int &M, const int &N, const float &alpha,
222  const float *A, const int &IA, const int &JA, const int *DESCA,
223  const float *X, const int &IX, const int &JX, const int *DESCX, const int &INCX,
224  const float &beta,
225  float *Y, const int &IY, const int &JY, const int *DESCY, const int &INCY)
226  {
227  psgemv_(&transa, &M, &N, &alpha,
228  A, &IA, &JA, DESCA,
229  X, &IX, &JX, DESCX, &INCX,
230  &beta, Y, &IY, &JY, DESCY, &INCY);
231  }
232 
233  static inline
234  void pgemv_f(const char &transa, const int &M, const int &N, const double &alpha,
235  const double *A, const int &IA, const int &JA, const int *DESCA,
236  const double *X, const int &IX, const int &JX, const int *DESCX, const int &INCX,
237  const double &beta,
238  double *Y, const int &IY, const int &JY, const int *DESCY, const int &INCY)
239  {
240  pdgemv_(&transa, &M, &N, &alpha,
241  A, &IA, &JA, DESCA,
242  X, &IX, &JX, DESCX, &INCX,
243  &beta, Y, &IY, &JY, DESCY, &INCY);
244  }
245 
246  static inline
247  void pgemv_f(const char &transa, const int &M, const int &N, const std::complex<float> &alpha,
248  const std::complex<float> *A, const int &IA, const int &JA, const int *DESCA,
249  const std::complex<float> *X, const int &IX, const int &JX, const int *DESCX, const int &INCX,
250  const std::complex<float> &beta,
251  std::complex<float> *Y, const int &IY, const int &JY, const int *DESCY, const int &INCY)
252  {
253  pcgemv_(&transa, &M, &N, &alpha,
254  A, &IA, &JA, DESCA,
255  X, &IX, &JX, DESCX, &INCX,
256  &beta, Y, &IY, &JY, DESCY, &INCY);
257  }
258 
259  static inline
260  void pgemv_f(const char &transa, const int &M, const int &N, const std::complex<double> &alpha,
261  const std::complex<double> *A, const int &IA, const int &JA, const int *DESCA,
262  const std::complex<double> *X, const int &IX, const int &JX, const int *DESCX, const int &INCX,
263  const std::complex<double> &beta,
264  std::complex<double> *Y, const int &IY, const int &JY, const int *DESCY, const int &INCY)
265  {
266  pzgemv_(&transa, &M, &N, &alpha,
267  A, &IA, &JA, DESCA,
268  X, &IX, &JX, DESCX, &INCX,
269  &beta, Y, &IY, &JY, DESCY, &INCY);
270  }
271 
272  static inline
273  void pgemm_f(const char &transa, const char &transb,
274  const int &M, const int &N, const int &K,
275  const float &alpha,
276  const float *A, const int &IA, const int &JA, const int *DESCA,
277  const float *B, const int &IB, const int &JB, const int *DESCB,
278  const float &beta,
279  float *C, const int &IC, const int &JC, const int *DESCC)
280  {
281  psgemm_(&transa, &transb, &M, &N, &K, &alpha,
282  A, &IA, &JA, DESCA,
283  B, &IB, &JB, DESCB,
284  &beta,
285  C, &IC, &JC, DESCC);
286  }
287 
288  static inline
289  void pgemm_f(const char &transa, const char &transb,
290  const int &M, const int &N, const int &K,
291  const double &alpha,
292  const double *A, const int &IA, const int &JA, const int *DESCA,
293  const double *B, const int &IB, const int &JB, const int *DESCB,
294  const double &beta,
295  double *C, const int &IC, const int &JC, const int *DESCC)
296  {
297  pdgemm_(&transa, &transb, &M, &N, &K, &alpha,
298  A, &IA, &JA, DESCA,
299  B, &IB, &JB, DESCB,
300  &beta,
301  C, &IC, &JC, DESCC);
302  }
303 
304  static inline
305  void pgemm_f(const char &transa, const char &transb,
306  const int &M, const int &N, const int &K,
307  const std::complex<float> &alpha,
308  const std::complex<float> *A, const int &IA, const int &JA, const int *DESCA,
309  const std::complex<float> *B, const int &IB, const int &JB, const int *DESCB,
310  const std::complex<float> &beta,
311  std::complex<float> *C, const int &IC, const int &JC, const int *DESCC)
312  {
313  pcgemm_(&transa, &transb, &M, &N, &K, &alpha,
314  A, &IA, &JA, DESCA,
315  B, &IB, &JB, DESCB,
316  &beta,
317  C, &IC, &JC, DESCC);
318  }
319 
320  static inline
321  void pgemm_f(const char &transa, const char &transb,
322  const int &M, const int &N, const int &K,
323  const std::complex<double> &alpha,
324  const std::complex<double> *A, const int &IA, const int &JA, const int *DESCA,
325  const std::complex<double> *B, const int &IB, const int &JB, const int *DESCB,
326  const std::complex<double> &beta,
327  std::complex<double> *C, const int &IC, const int &JC, const int *DESCC)
328  {
329  pzgemm_(&transa, &transb, &M, &N, &K, &alpha,
330  A, &IA, &JA, DESCA,
331  B, &IB, &JB, DESCB,
332  &beta,
333  C, &IC, &JC, DESCC);
334  }
335 
336  static inline
337  void pgemr2d_f(const int m, const int n,
338  const float *a, const int ia, const int ja, const int *desca,
339  float *b, const int ib, const int jb, const int *descb,
340  const int ictxt)
341  {
342  psgemr2d_(&m, &n, a, &ia, &ja, desca, b, &ib, &jb, descb, &ictxt);
343  }
344 
345  static inline
346  void pgemr2d_f(const int m, const int n,
347  const double *a, const int ia, const int ja, const int *desca,
348  double *b, const int ib, const int jb, const int *descb,
349  const int ictxt)
350  {
351  pdgemr2d_(&m, &n, a, &ia, &ja, desca, b, &ib, &jb, descb, &ictxt);
352  }
353 
354  static inline
355  void pgemr2d_f(const int m, const int n,
356  const std::complex<float> *a, const int ia, const int ja, const int *desca,
357  std::complex<float> *b, const int ib, const int jb, const int *descb,
358  const int ictxt)
359  {
360  pcgemr2d_(&m, &n, a, &ia, &ja, desca, b, &ib, &jb, descb, &ictxt);
361  }
362 
363  static inline
364  void pgemr2d_f(const int m, const int n,
365  const std::complex<double> *a, const int ia, const int ja, const int *desca,
366  std::complex<double> *b, const int ib, const int jb, const int *descb,
367  const int ictxt)
368  {
369  pzgemr2d_(&m, &n, a, &ia, &ja, desca, b, &ib, &jb, descb, &ictxt);
370  }
371 
372  static inline
373  void psyev_f(const char &jobz, const char &uplo,
374  const int &n, float *A, const int &ia, const int &ja, const int *desca,
375  float *W, float *Z, const int &iz, const int &jz, const int *descz,
376  float *work, const int &lwork, float *rwork, const int &lrwork, int &info)
377  {
378  pssyev_(&jobz, &uplo, &n, A, &ia, &ja, desca,
379  W, Z, &iz, &jz, descz,
380  work, &lwork, rwork, &lrwork, &info);
381  }
382  static inline
383  void psyev_f(const char &jobz, const char &uplo,
384  const int &n, double *A, const int &ia, const int &ja, const int *desca,
385  double *W, double *Z, const int &iz, const int &jz, const int *descz,
386  double *work, const int &lwork, double *rwork, const int &lrwork, int &info)
387  {
388  pdsyev_(&jobz, &uplo, &n, A, &ia, &ja, desca,
389  W, Z, &iz, &jz, descz,
390  work, &lwork, rwork, &lrwork, &info);
391  }
392  static inline
393  void pheev_f(const char &jobz, const char &uplo,
394  const int &n, std::complex<float> *A, const int &ia, const int &ja, const int *desca,
395  float *W, std::complex<float> *Z, const int &iz, const int &jz, const int *descz,
396  std::complex<float> *work, const int &lwork, float *rwork, const int &lrwork, int &info)
397  {
398  pcheev_(&jobz, &uplo, &n, A, &ia, &ja, desca,
399  W, Z, &iz, &jz, descz,
400  work, &lwork, rwork, &lrwork, &info);
401  }
402  static inline
403  void pheev_f(const char &jobz, const char &uplo,
404  const int &n, std::complex<double> *A, const int &ia, const int &ja, const int *desca,
405  double *W, std::complex<double> *Z, const int &iz, const int &jz, const int *descz,
406  std::complex<double> *work, const int &lwork, double *rwork, const int &lrwork, int &info)
407  // const char, const char, const int, std::complex<double> *, int, int, const int [9], double *, std::complex<double> *, int, int, const int [9], double *, int, std::complex<double> *, int, int
408  {
409  pzheev_(&jobz, &uplo, &n, A, &ia, &ja, desca,
410  W, Z, &iz, &jz, descz,
411  work, &lwork, rwork, &lrwork, &info);
412  }
413 
414  static inline
415  void pgetrf_f(const int &m, const int &n, float *a, const int &ia, const int &ja, const int *desca, int *ipiv, int &info)
416  {
417  psgetrf_(&m, &n, a, &ia, &ja, desca, ipiv, &info);
418  }
419 
420  static inline
421  void pgetrf_f(const int &m, const int &n, double *a, const int &ia, const int &ja, const int *desca, int *ipiv, int &info)
422  {
423  pdgetrf_(&m, &n, a, &ia, &ja, desca, ipiv, &info);
424  }
425 
426  static inline
427  void pgetrf_f(const int &m, const int &n, std::complex<float> *a, const int &ia, const int &ja, const int *desca, int *ipiv, int &info)
428  {
429  pcgetrf_(&m, &n, a, &ia, &ja, desca, ipiv, &info);
430  }
431 
432  static inline
433  void pgetrf_f(const int &m, const int &n, std::complex<double> *a, const int &ia, const int &ja, const int *desca, int *ipiv, int &info)
434  {
435  pzgetrf_(&m, &n, a, &ia, &ja, desca, ipiv, &info);
436  }
437 
438  static inline
439  void pgetri_f(const int &n, float *a, const int &ia, const int &ja, const int *desca, int *ipiv, float *work, const int &lwork, int *iwork, const int &liwork, int &info)
440  {
441  psgetri_(&n, a, &ia, &ja, desca, ipiv, work, &lwork, iwork, &liwork, &info);
442  }
443 
444  static inline
445  void pgetri_f(const int &n, double *a, const int &ia, const int &ja, const int *desca, int *ipiv, double *work, const int &lwork, int *iwork, const int &liwork, int &info)
446  {
447  pdgetri_(&n, a, &ia, &ja, desca, ipiv, work, &lwork, iwork, &liwork, &info);
448  }
449 
450  static inline
451  void pgetri_f(const int &n, std::complex<float> *a, const int &ia, const int &ja, const int *desca, int *ipiv, std::complex<float> *work, const int &lwork, int *iwork, const int &liwork, int &info)
452  {
453  pcgetri_(&n, a, &ia, &ja, desca, ipiv, work, &lwork, iwork, &liwork, &info);
454  }
455 
456  static inline
457  void pgetri_f(const int &n, std::complex<double> *a, const int &ia, const int &ja, const int *desca, int *ipiv, std::complex<double> *work, const int &lwork, int *iwork, const int &liwork, int &info)
458  {
459  pzgetri_(&n, a, &ia, &ja, desca, ipiv, work, &lwork, iwork, &liwork, &info);
460  }
461 
462 };
463 
464 #endif
Definition: scalapack_connector.h:9
Definition: matrix.h:23