lapack_connector.h Source File

LibRPA: lapack_connector.h Source File
LibRPA
lapack_connector.h
1 #ifndef LAPACKCONNECTOR_HPP
2 #define LAPACKCONNECTOR_HPP
3 
4 #include <vector>
5 #include <new>
6 #include <stdexcept>
7 #include <iostream>
8 #include <cassert>
9 #include "base_utility.h"
10 #include "matrix.h"
11 #include "complexmatrix.h"
12 #include "interface/blas_lapack.h"
13 
// Class LapackConnector provides C++ wrappers around Fortran BLAS/LAPACK routines.
// The functions in this class are intended to be static and inline, so no instance is needed.
// Usage example: LapackConnector::functionname(parameter list).
18 {
19 public:
20  // Transpose an row-major array to the fortran-form colum-major array.
21  template <typename T>
22  static inline T* transpose(const T* a, const int &n, const int &lda, bool conjugate = false)
23  {
24  T* a_fort = new T[lda*n];
25  if (is_complex<T>() && conjugate)
26  for (int i = 0; i < n; i++)
27  for (int j = 0; j < lda; j++)
28  a_fort[i*lda+j] = get_conj(a[j*n+i]);
29  else
30  for (int i = 0; i < n; i++)
31  for (int j = 0; j < lda; j++)
32  a_fort[i*lda+j] = a[j*n+i];
33  return a_fort;
34  }
35 
36  template <typename T>
37  static inline void transpose(const T* a, T* a_trans, const int &n, const int &lda, bool conjugate = false)
38  {
39  if (is_complex<T>() && conjugate)
40  for (int i = 0; i < n; i++)
41  for (int j = 0; j < lda; j++)
42  a_trans[j*n+i] = get_conj(a[i*lda+j]);
43  else
44  for (int i = 0; i < n; i++)
45  for (int j = 0; j < lda; j++)
46  a_trans[j*n+i] = a[i*lda+j];
47  }
48 
49  // Transpose the complex matrix to the fortran-form real-complex array.
50  static inline
51  complex<double>* transpose(const ComplexMatrix& a, const int n, const int lda)
52  {
53  complex<double>* aux = new complex<double>[lda*n];
54  for (int i = 0; i < n; ++i)
55  {
56  for (int j = 0; j < lda; ++j)
57  {
58  aux[i*lda+j] = a(j,i); // aux[i*lda+j] means aux[i][j] in semantic, not in syntax!
59  }
60  }
61  return aux;
62  }
63  // Transpose the matrix to the fortran-form real array.
64  static inline
65  double * transpose_matrix(const matrix& a, const int n, const int lda)
66  {
67  double * aux = new double [lda*n];
68  std::cout << " lda=" << lda << " n=" << n << std::endl;
69  for (int i=0; i<n; i++)
70  {
71  for (int j=0; j<lda; j++)
72  {
73  aux[i*lda+j] = a(j,i); // aux[i*lda+j] means aux[i][j] in semantic, not in syntax!
74  }
75  }
76  return aux;
77  }
78 
79  // Transpose the fortran-form real-complex array to the complex matrix.
80  static inline
81  void transpose(const complex<double>* aux, ComplexMatrix& a, const int n, const int lda)
82  {
83  for (int i = 0; i < n; ++i)
84  {
85  for (int j = 0; j < lda; ++j)
86  {
87  a(j, i) = aux[i*lda+j]; // aux[i*lda+j] means aux[i][j] in semantic, not in syntax!
88  }
89  }
90  }
91 
92  // Transpose the fortran-form real array to the matrix.
93  static inline
94  void transpose_matrix(const double *aux, matrix &a, const int n, const int lda)
95  {
96  for (int i = 0; i < n; ++i)
97  {
98  for (int j = 0; j < lda; ++j)
99  {
100  a(j, i) = aux[i*lda+j]; // aux[i*lda+j] means aux[i][j] in semantic, not in syntax!
101  }
102  }
103  }
104 
// Option-character helpers for wrappers that hand a row-major (C++) buffer
// straight to a Fortran routine (e.g. dsyev, dpotrf, gemv, zherk). Such a
// buffer is the transpose of the column-major array Fortran sees, so the
// side/uplo/trans flag has to be flipped. Each helper throws
// std::invalid_argument for any other character.

// Peize Lin add 2015-12-27: 'L' <-> 'R'.
static inline
char change_side(const char &side)
{
    if (side == 'L')
        return 'R';
    if (side == 'R')
        return 'L';
    throw std::invalid_argument("Side must be 'L' or 'R'");
}

// Peize Lin add 2015-12-27: 'U' <-> 'L'.
static inline
char change_uplo(const char &uplo)
{
    if (uplo == 'U')
        return 'L';
    if (uplo == 'L')
        return 'U';
    throw std::invalid_argument("uplo must be 'U' or 'L'");
}

// Peize Lin add 2019-04-14: 'N' <-> 'T'.
static inline
char change_trans_NT(const char &trans)
{
    if (trans == 'N')
        return 'T';
    if (trans == 'T')
        return 'N';
    throw std::invalid_argument("trans must be 'N' or 'T'");
}

// Peize Lin add 2019-04-14: 'N' <-> 'C'.
static inline
char change_trans_NC(const char &trans)
{
    if (trans == 'N')
        return 'C';
    if (trans == 'C')
        return 'N';
    throw std::invalid_argument("trans must be 'N' or 'C'");
}
151 
152 public:
153 
154  static inline
155  int ilaenv( int ispec, const char *name,const char *opts,const int n1,const int n2,
156  const int n3,const int n4)
157  {
158  const int nb = ilaenv_(&ispec, name, opts, &n1, &n2, &n3, &n4);
159  return nb;
160  }
161 
162  static inline
163  void zgesv (const int n,const int nrhs,ComplexMatrix &A,const int lda,
164  const int *ipiv,complex<double> *B,const int ldb,int* info)
165  {
166  complex<double> *aux = LapackConnector::transpose(A, n, lda);
167  zgesv_(&n, &nrhs, aux, &lda, ipiv, B, &ldb, info);
168  LapackConnector::transpose(aux, A, n, lda);
169  }
170 
171  // wrap function of fortran lapack routine zhegv.
172  static inline
173  void zhegv( const int itype,const char jobz,const char uplo,const int n,ComplexMatrix& a,
174  const int lda,ComplexMatrix& b,const int ldb,double* w,complex<double>* work,
175  int lwork,double* rwork,int info)
176  { // Transpose the complex matrix to the fortran-form real-complex array.
177  complex<double>* aux = LapackConnector::transpose(a, n, lda);
178  complex<double>* bux = LapackConnector::transpose(b, n, ldb);
179 
180  // call the fortran routine
181  zhegv_(&itype, &jobz, &uplo, &n, aux, &lda, bux, &ldb, w, work, &lwork, rwork, &info);
182  // Transpose the fortran-form real-complex array to the complex matrix.
183  LapackConnector::transpose(aux, a, n, lda);
184  LapackConnector::transpose(bux, b, n, ldb);
185  // free the memory.
186  delete[] aux;
187  delete[] bux;
188  }
189 
190 
191  // calculate the selected eigenvalues.
192  // mohan add 2010/03/21
193  // static inline
194  // void sspgvx(const int itype, const char jobz,const char range,const char uplo,
195  // const int n,const matrix &ap,const matrix &bp,const double vl,
196  // const int vu, const int il, const int iu,const double abstol,
197  // const int m,double* w,matrix &z,const int ldz,
198  // double *work,int* iwork,int* ifail,int* info)
199  // {
200  // // Transpose the complex matrix to the fortran-form real*16 array.
201  // double* aux = LapackConnector::transpose_matrix(ap, n, n);
202  // double* bux = LapackConnector::transpose_matrix(bp, n, n);
203  // double* zux = new double[n*iu];
204  //
205  // // call the fortran routine
206  // sspgvx_(&itype, &jobz, &range, &uplo,
207  // &n, aux, bux, &vl,
208  // &vu, &il,&iu, &abstol,
209  // &m, w, zux, &ldz,
210  // work, iwork, ifail, info);
211  //
212  // // Transpose the fortran-form real*16 array to the matrix
213  // LapackConnector::transpose_matrix(zux, z, iu, n);
214  //
215  // // free the memory.
216  // delete[] aux;
217  // delete[] bux;
218  // delete[] zux;
219  // }
220 
221  // calculate the eigenvalues and eigenfunctions of a real symmetric matrix.
222  static inline
223  void dsygv( const int itype,const char jobz,const char uplo,const int n,matrix& a,
224  const int lda,matrix& b,const int ldb,double* w,double* work,
225  int lwork,int *info)
226  { // Transpose the complex matrix to the fortran-form real-complex array.
227  double* aux = new double[lda*n];
228  double* bux = new double[lda*n];
229  for (int i=0; i<n; i++)
230  {
231  for (int j=0; j<lda; j++)
232  {
233  aux[i*lda+j] = a(j,i); // aux[i*lda+j] means aux[i][j] in semantic, not in syntax!
234  bux[i*lda+j] = b(j,i);
235  }
236  }
237  // call the fortran routine
238  dsygv_(&itype, &jobz, &uplo, &n, aux, &lda, bux, &ldb, w, work, &lwork, info);
239  for (int i=0; i<n; i++)
240  {
241  for(int j=0; j<lda; ++j)
242  {
243  a(j,i)=aux[i*lda+j];
244  b(j,i)=bux[i*lda+j];
245  }
246  }
247 
248  // free the memory.
249  delete[] aux;
250  delete[] bux;
251  }
252 
253  // wrap function of fortran lapack routine dsyev.
254  static inline
255  void dsyev(const char jobz,const char uplo,const int n,double *a,
256  const int lda,double *w,double *work,const int lwork, int &info)
257  {
258  const char uplo_changed = change_uplo(uplo);
259  dsyev_(&jobz, &uplo_changed, &n, a, &lda, w, work, &lwork, &info);
260  }
261  // static inline
262  // void dsyev( const char jobz, const char uplo, matrix &a, double* w, int info ) // Peize Lin update 2017-10-17
263  // {
264  // assert(a.nr==a.nc);
265  // const char uplo_changed = change_uplo(uplo);
266 
267  // double work_tmp;
268  // constexpr int minus_one = -1;
269  // dsyev_(&jobz, &uplo_changed, &a.nr, a.c, &a.nr, w, &work_tmp, &minus_one, &info); // get best lwork
270 
271  // const int lwork = work_tmp;
272  // vector<double> work(std::max(1,lwork));
273  // dsyev_(&jobz, &uplo_changed, &a.nr, a.c, &a.nr, w, VECTOR_TO_PTR(work), &lwork, &info);
274  // }
275  // wrap function of fortran lapack routine zheev.
276  static inline
277  void zheev( const char jobz,
278  const char uplo,
279  const int n,
280  ComplexMatrix& a,
281  const int lda,
282  double* w,
283  complex< double >* work,
284  const int lwork,
285  double* rwork,
286  int *info )
287  { // Transpose the complex matrix to the fortran-form real-complex array.
288  complex<double> *aux = LapackConnector::transpose(a, n, lda);
289  // call the fortran routine
290  zheev_(&jobz, &uplo, &n, aux, &lda, w, work, &lwork, rwork, info);
291  // Transpose the fortran-form real-complex array to the complex matrix.
292  LapackConnector::transpose(aux, a, n, lda);
293  // free the memory.
294  delete[] aux;
295  }
296 
297  // wrap function of fortran lapack routine dsytrf.
298  static inline
299  void dsytrf(char uplo, int n, matrix& a, int lda, int *ipiv,
300  double *work, int lwork , int info)
301  {
302  double * aux = LapackConnector::transpose_matrix(a, n, lda) ;
303  dsytrf_(&uplo, &n, aux, &lda, ipiv, work, &lwork, &info);
304  LapackConnector::transpose_matrix(aux, a, n, lda);
305  }
306 
307  // wrap function of fortran lapack routine dsytri.
308  static inline
309  void dsytri(char uplo, int n, matrix& a, int lda, int *iwork,double * work, int info)
310  {
311  double * aux = LapackConnector::transpose_matrix(a, n, lda) ;
312  dsytri_(&uplo, &n, aux, &lda, iwork, work, &info);
313  LapackConnector::transpose_matrix(aux, a, n, lda);
314  }
315 
316 
317  // wrap function of fortran lapack routine zheev.
318  static inline
319  void zhegvx( const int itype, const char jobz, const char range, const char uplo,
320  const int n, const ComplexMatrix& a, const int lda, const ComplexMatrix& b,
321  const int ldb, const double vl, const double vu, const int il, const int iu,
322  const double abstol, const int m, double* w, ComplexMatrix& z, const int ldz,
323  complex<double>* work, const int lwork, double* rwork, int* iwork,
324  int* ifail, int info)
325  {
326  // Transpose the complex matrix to the fortran-form real-complex array.
327  complex<double>* aux = LapackConnector::transpose(a, n, lda);
328  complex<double>* bux = LapackConnector::transpose(b, n, ldb);
329  complex<double>* zux = new complex<double>[n*iu];// mohan modify 2009-08-02
330  //for(int i=0; i<n*iu; i++) zux[i] = complex<double>(0.0,0.0);
331 
332  // call the fortran routine
333  zhegvx_(&itype, &jobz, &range, &uplo, &n, aux, &lda, bux, &ldb, &vl,
334  &vu, &il,&iu, &abstol, &m, w, zux, &ldz, work, &lwork, rwork, iwork, ifail, &info);
335 
336  // Transpose the fortran-form real-complex array to the complex matrix
337  LapackConnector::transpose(zux, z, iu, n); // mohan modify 2009-08-02
338 
339  // free the memory.
340  delete[] aux;
341  delete[] bux;
342  delete[] zux;
343 
344  }
345  static inline
346  void dgesvd(const char jobu,const char jobvt,const int m,const int n,
347  matrix &a,const int lda,double *s,matrix &u,const int ldu,
348  matrix &vt,const int ldvt,double *work,const int lwork,int info)
349  {
350  //Transpose the matrix to the fortran-form
351 
352  double *aux = LapackConnector::transpose_matrix(a, n, lda);
353  double *uux = LapackConnector::transpose_matrix(u, m, ldu);
354  double *vtux = LapackConnector::transpose_matrix(vt, n, ldvt);
355 
356  dgesvd_(&jobu, &jobvt , &m, &n, aux, &lda, s, uux, &ldu, vtux, &ldvt, work, &lwork, &info);
357 
358  LapackConnector::transpose_matrix(aux, a, n, lda);
359  LapackConnector::transpose_matrix(uux, u, m, ldu);
360  LapackConnector::transpose_matrix(vtux, vt, n, ldvt);
361 
362  delete[] aux;
363  delete[] uux;
364  delete[] vtux;
365 
366  }
367 
368  static inline
369  void zpotrf(char uplo,int n,ComplexMatrix &a,const int lda,int* info)
370  {
371  complex<double> *aux = LapackConnector::transpose(a, n, lda);
372  zpotrf_( &uplo, &n, aux, &lda, info);
373  LapackConnector::transpose(aux, a, n, lda);
374  delete[] aux;
375  return;
376  }
377 
378  static inline
379  void zpotri(char uplo,int n,ComplexMatrix &a,const int lda,int* info)
380  {
381  complex<double> *aux = LapackConnector::transpose(a, n, lda);
382  zpotri_( &uplo, &n, aux, &lda, info);
383  LapackConnector::transpose(aux, a, n, lda);
384  delete[] aux;
385  return;
386  }
387 
388  // static inline
389  // void spotrf(char uplo,int n,matrix &a, int lda,int *info)
390  // {
391  // double *aux = LapackConnector::transpose_matrix(a, n, lda);
392  // for(int i=0; i<4; i++)std::cout << "\n aux=" << aux[i];
393  // spotrf_( &uplo, &n, aux, &lda, info);
394  // for(int i=0; i<4; i++)std::cout << "\n aux=" << aux[i];
395  // LapackConnector::transpose_matrix(aux, a, n, lda);
396  // delete[] aux;
397  // return;
398  // }
399  //
400  // static inline
401  // void spotri(char uplo,int n,matrix &a, int lda, int *info)
402  // {
403  // double *aux = LapackConnector::transpose_matrix(a, n, lda);
404  // for(int i=0; i<4; i++)std::cout << "\n aux=" << aux[i];
405  // spotri_( &uplo, &n, aux, &lda, info);
406  // for(int i=0; i<4; i++)std::cout << "\n aux=" << aux[i];
407  // LapackConnector::transpose_matrix(aux, a, n, lda);
408  // delete[] aux;
409  // return;
410  // }
411 
412  static inline
413  void zgetrf(int m, int n, ComplexMatrix &a, const int lda, int *ipiv, int *info)
414  {
415  complex<double> *aux = LapackConnector::transpose(a, n, lda);
416  zgetrf_( &m, &n, aux, &lda, ipiv, info);
417  LapackConnector::transpose(aux, a, n, lda);
418  delete[] aux;
419  return;
420  }
421  static inline
422  void zgetri(int n, ComplexMatrix &a, int lda, int *ipiv, complex<double> * work, int lwork, int *info)
423  {
424  complex<double> *aux = LapackConnector::transpose(a, n, lda);
425  zgetri_( &n, aux, &lda, ipiv, work, &lwork, info);
426  LapackConnector::transpose(aux, a, n, lda);
427  delete[] aux;
428  return;
429  }
430 
431  static inline void getrf(const int &m, const int &n, float *A, const int &lda, int *ipiv, int &info)
432  {
433  float *a_fort = transpose(A, n, lda);
434  sgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
435  transpose(a_fort, A, n, lda);
436  delete [] a_fort;
437  }
438 
439  static inline void getrf(const int &m, const int &n, double *A, const int &lda, int *ipiv, int &info)
440  {
441  double *a_fort = transpose(A, n, lda);
442  dgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
443  transpose(a_fort, A, n, lda);
444  delete [] a_fort;
445  }
446 
447  static inline void getrf(const int &m, const int &n, std::complex<float> *A, const int &lda, int *ipiv, int &info)
448  {
449  std::complex<float> *a_fort = transpose(A, n, lda);
450  cgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
451  transpose(a_fort, A, n, lda);
452  delete [] a_fort;
453  }
454 
455  static inline void getrf(const int &m, const int &n, std::complex<double> *A, const int &lda, int *ipiv, int &info)
456  {
457  std::complex<double> *a_fort = transpose(A, n, lda);
458  zgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
459  transpose(a_fort, A, n, lda);
460  delete [] a_fort;
461  }
462 
463  static inline void getrf_f(const int &m, const int &n, float *A, const int &lda, int *ipiv, int &info)
464  {
465  sgetrf_(&m, &n, A, &lda, ipiv, &info);
466  }
467 
468  static inline void getrf_f(const int &m, const int &n, double *A, const int &lda, int *ipiv, int &info)
469  {
470  dgetrf_(&m, &n, A, &lda, ipiv, &info);
471  }
472 
473  static inline void getrf_f(const int &m, const int &n, std::complex<float> *A, const int &lda, int *ipiv, int &info)
474  {
475  cgetrf_(&m, &n, A, &lda, ipiv, &info);
476  }
477 
478  static inline void getrf_f(const int &m, const int &n, std::complex<double> *A, const int &lda, int *ipiv, int &info)
479  {
480  zgetrf_(&m, &n, A, &lda, ipiv, &info);
481  }
482 
483  static inline void getri(const int &n, float *A, const int &lda, int *ipiv, float *work, const int &lwork, int &info)
484  {
485  float *a_fort = transpose(A, n, lda);
486  sgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
487  transpose(a_fort, A, n, lda);
488  delete [] a_fort;
489  }
490 
491  static inline void getri(const int &n, double *A, const int &lda, int *ipiv, double *work, const int &lwork, int &info)
492  {
493  double *a_fort = transpose(A, n, lda);
494  dgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
495  transpose(a_fort, A, n, lda);
496  delete [] a_fort;
497  }
498 
499  static inline void getri(const int &n, std::complex<float> *A, const int &lda, int *ipiv, std::complex<float> *work, const int &lwork, int &info)
500  {
501  std::complex<float> *a_fort = transpose(A, n, lda);
502  cgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
503  transpose(a_fort, A, n, lda);
504  delete [] a_fort;
505  }
506 
507  static inline void getri(const int &n, std::complex<double> *A, const int &lda, int *ipiv, std::complex<double> *work, const int &lwork, int &info)
508  {
509  std::complex<double> *a_fort = transpose(A, n, lda);
510  zgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
511  transpose(a_fort, A, n, lda);
512  delete [] a_fort;
513  }
514 
515  static inline void getri_f(const int &n, float *A, const int &lda, int *ipiv, float *work, const int &lwork, int &info)
516  {
517  sgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
518  }
519 
520  static inline void getri_f(const int &n, double *A, const int &lda, int *ipiv, double *work, const int &lwork, int &info)
521  {
522  dgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
523  }
524 
525  static inline void getri_f(const int &n, std::complex<float> *A, const int &lda, int *ipiv, std::complex<float> *work, const int &lwork, int &info)
526  {
527  cgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
528  }
529 
530  static inline void getri_f(const int &n, std::complex<double> *A, const int &lda, int *ipiv, std::complex<double> *work, const int &lwork, int &info)
531  {
532  zgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
533  }
534 
535  // Peize Lin add 2016-07-09
536  static inline
537  void dpotrf( char uplo, const int n, matrix &a, const int lda, int *info )
538  {
539  const char uplo_changed = change_uplo(uplo);
540  dpotrf_( &uplo_changed, &n, a.c, &lda, info );
541  }
542 
543  // Peize Lin add 2016-07-09
544  static inline
545  void dpotri( char uplo, const int n, matrix &a, const int lda, int *info )
546  {
547  const char uplo_changed = change_uplo(uplo);
548  dpotri_( &uplo_changed, &n, a.c, &lda, info);
549  }
550 
551  // Peize Lin add 2016-08-04
552  // y=a*x+y
553  static inline
554  void axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY)
555  {
556  saxpy_(&n, &alpha, X, &incX, Y, &incY);
557  }
558  static inline
559  void axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY)
560  {
561  daxpy_(&n, &alpha, X, &incX, Y, &incY);
562  }
563  static inline
564  void axpy( const int n, const complex<float> alpha, const complex<float> *X, const int incX, complex<float> *Y, const int incY)
565  {
566  caxpy_(&n, &alpha, X, &incX, Y, &incY);
567  }
568  static inline
569  void axpy( const int n, const complex<double> alpha, const complex<double> *X, const int incX, complex<double> *Y, const int incY)
570  {
571  zaxpy_(&n, &alpha, X, &incX, Y, &incY);
572  }
573 
574  // Peize Lin add 2016-08-04
575  // x=a*x
576  static inline
577  void scal( const int n, const float alpha, float *X, const int incX)
578  {
579  sscal_(&n, &alpha, X, &incX);
580  }
581  static inline
582  void scal( const int n, const double alpha, double *X, const int incX)
583  {
584  dscal_(&n, &alpha, X, &incX);
585  }
586  static inline
587  void scal( const int n, const complex<float> alpha, complex<float> *X, const int incX)
588  {
589  cscal_(&n, &alpha, X, &incX);
590  }
591  static inline
592  void scal( const int n, const complex<double> alpha, complex<double> *X, const int incX)
593  {
594  zscal_(&n, &alpha, X, &incX);
595  }
596 
597  // Peize Lin add 2017-10-27
598  // d=x*y
599  static inline
600  float dot( const int n, const float *X, const int incX, const float *Y, const int incY)
601  {
602  return sdot_(&n, X, &incX, Y, &incY);
603  }
604  static inline
605  double dot( const int n, const double *X, const int incX, const double *Y, const int incY)
606  {
607  return ddot_(&n, X, &incX, Y, &incY);
608  }
609 
610  // minyez add 2022-11-15
611  static inline
612  std::complex<float> dot( const int n, const std::complex<float> *X, const int incX, const std::complex<float> *Y, const int incY)
613  {
614  std::complex<float> result;
615  cdotu_(&result, &n, X, &incX, Y, &incY);
616  return result;
617  }
618 
619  static inline
620  std::complex<double> dot( const int n, const std::complex<double> *X, const int incX, const std::complex<double> *Y, const int incY)
621  {
622  std::complex<double> result;
623  zdotu_(&result, &n, X, &incX, Y, &incY);
624  return result;
625  }
626 
627  // minyez add 2022-05-12
628  // matrix-vector product
629  // single-prec version
630  static inline
631  void gemv(const char transa, const int m, const int n, const float alpha, const float *a,
632  const int lda, const float *x, const int incx, const float beta, float *y, const int incy)
633  {
634  char transa_f = change_trans_NT(transa);
635  sgemv_(&transa_f, &n, &m, &alpha, a, &lda, x, &incx, &beta, y, &incy);
636  }
637  // double-prec version
638  static inline
639  void gemv(const char transa, const int m, const int n, const double alpha, const double *a,
640  const int lda, const double *x, const int incx, const double beta, double *y, const int incy)
641  {
642  char transa_f = change_trans_NT(transa);
643  dgemv_(&transa_f, &n, &m, &alpha, a, &lda, x, &incx, &beta, y, &incy);
644  }
645 
646  inline void gemv_f(const char &transa, const int &m, const int &n,
647  const float &alpha, const float *a, const int &lda,
648  const float *x, const int &incx, const float &beta, float *y, const int &incy)
649  {
650  sgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
651  }
652 
653  inline void gemv_f(const char &transa, const int &m, const int &n,
654  const double &alpha, const double *a, const int &lda,
655  const double *x, const int &incx, const double &beta, double *y, const int &incy)
656  {
657  dgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
658  }
659 
660  inline void gemv_f(const char &transa, const int &m, const int &n,
661  const std::complex<float> &alpha, const std::complex<float> *a, const int &lda,
662  const std::complex<float> *x, const int &incx, const std::complex<float> &beta, std::complex<float> *y, const int &incy)
663  {
664  cgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
665  }
666 
667  inline void gemv_f(const char &transa, const int &m, const int &n,
668  const std::complex<double> &alpha, const std::complex<double> *a, const int &lda,
669  const std::complex<double> *x, const int &incx, const std::complex<double> &beta, std::complex<double> *y, const int &incy)
670  {
671  zgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
672  }
673 
674  // Peize Lin add 2017-10-27, fix bug trans 2019-01-17
675  // C = a * A.? * B.? + b * C
676  static inline
677  void gemm(const char transa, const char transb, const int m, const int n, const int k,
678  const float alpha, const float *a, const int lda, const float *b, const int ldb,
679  const float beta, float *c, const int ldc)
680  {
681  sgemm_(&transb, &transa, &n, &m, &k,
682  &alpha, b, &ldb, a, &lda,
683  &beta, c, &ldc);
684  }
685 
686  static inline
687  void gemm(const char transa, const char transb, const int m, const int n, const int k,
688  const double alpha, const double *a, const int lda, const double *b, const int ldb,
689  const double beta, double *c, const int ldc)
690  {
691  dgemm_(&transb, &transa, &n, &m, &k,
692  &alpha, b, &ldb, a, &lda,
693  &beta, c, &ldc);
694  }
695 
696  static inline
697  void gemm(const char transa, const char transb, const int m,
698  const int n, const int k, const complex<float> alpha,
699  const complex<float> *a, const int lda,
700  const complex<float> *b, const int ldb,
701  const complex<float> beta, complex<float> *c, const int ldc)
702  {
703  cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda,
704  &beta, c, &ldc);
705  }
706 
707  static inline
708  void gemm(const char transa, const char transb, const int m,
709  const int n, const int k, const complex<double> alpha,
710  const complex<double> *a, const int lda,
711  const complex<double> *b, const int ldb,
712  const complex<double> beta, complex<double> *c, const int ldc)
713  {
714  zgemm_(&transb, &transa, &n, &m, &k,
715  &alpha, b, &ldb, a, &lda,
716  &beta, c, &ldc);
717  }
718 
719  static inline
720  void gemm_f(const char transa, const char transb,
721  const int m, const int n, const int k,
722  const float alpha, const float *a,
723  const int lda, const float *b, const int ldb,
724  const float beta, float *c, const int ldc)
725  {
726  sgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb,
727  &beta, c, &ldc);
728  }
729 
730  static inline void gemm_f(const char transa, const char transb,
731  const int m, const int n, const int k,
732  const complex<float> alpha,
733  const complex<float> *a, const int lda,
734  const complex<float> *b, const int ldb,
735  const complex<float> beta,
736  complex<float> *c, const int ldc)
737  {
738  cgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c,
739  &ldc);
740  }
741 
742  static inline void gemm_f(const char transa, const char transb,
743  const int m, const int n, const int k,
744  const double alpha, const double *a,
745  const int lda, const double *b, const int ldb,
746  const double beta, double *c, const int ldc)
747  {
748  dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb,
749  &beta, c, &ldc);
750  }
751 
752  static inline void gemm_f(const char transa, const char transb,
753  const int m, const int n, const int k,
754  const complex<double> alpha,
755  const complex<double> *a, const int lda,
756  const complex<double> *b, const int ldb,
757  const complex<double> beta,
758  complex<double> *c, const int ldc)
759  {
760  zgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c,
761  &ldc);
762  }
763 
764  // eigenvector of hermitian matrix, row-major
765  static inline
766  void heev(const char &jobz, const char &uplo, const int &n,
767  std::complex<float> *a, const int &lda, float *w,
768  std::complex<float> *work, const int &lwork, float *rwork, int &info)
769  {
770  complex<float> *aux = LapackConnector::transpose(a, n, lda);
771  // call the fortran routine
772  cheev_(&jobz, &uplo, &n, aux, &lda, w, work, &lwork, rwork, &info);
773  // Transpose the fortran-form real-complex array to the complex matrix.
774  LapackConnector::transpose(aux, a, n, lda);
775  // free the memory.
776  delete[] aux;
777  }
778 
779  static inline
780  void heev(const char &jobz, const char &uplo, const int &n,
781  std::complex<double> *a, const int &lda, double *w,
782  std::complex<double> *work, const int &lwork, double *rwork, int &info)
783  {
784  complex<double> *aux = LapackConnector::transpose(a, n, lda);
785  // call the fortran routine
786  zheev_(&jobz, &uplo, &n, aux, &lda, w, work, &lwork, rwork, &info);
787  // Transpose the fortran-form real-complex array to the complex matrix.
788  LapackConnector::transpose(aux, a, n, lda);
789  // free the memory.
790  delete[] aux;
791  }
792 
793  // eigenvector of hermitian matrix
794  static inline
795  void heev_f(const char &jobz, const char &uplo, const int &n,
796  std::complex<float> *a, const int &lda, float *w,
797  std::complex<float> *work, const int &lwork, float *rwork, int &info)
798  {
799  cheev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, rwork, &info);
800  }
801 
802  static inline
803  void heev_f(const char &jobz, const char &uplo, const int &n,
804  std::complex<double> *a, const int &lda, double *w,
805  std::complex<double> *work, const int &lwork, double *rwork, int &info)
806  {
807  zheev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, rwork, &info);
808  }
809 
810  // Peize Lin add 2018-06-12
811  // out = ||x||_2
812  // static inline
813  // float nrm2( const int n, const float *X, const int incX )
814  // {
815  // return snrm2_( &n, X, &incX );
816  // }
817  //static inline
818  // double nrm2( const int n, const double *X, const int incX )
819  // {
820  // return dnrm2_( &n, X, &incX );
821  // }
822  // static inline
823  // double nrm2( const int n, const complex<double> *X, const int incX )
824  // {
825  // return dznrm2_( &n, X, &incX );
826  // }
827 
828  static inline
829  void copy(const long n, const double *a, const int incx, double *b, const int incy)
830  {
831  dcopy_(&n, a, &incx, b, &incy);
832  }
833  static inline
834  void copy(const long n, const complex<double> *a, const int incx, complex<double> *b, const int incy)
835  {
836  zcopy_(&n, a, &incx, b, &incy);
837  }
838 
839  // Peize Lin add 2019-04-14
840  // if trans=='N': C = a * A * A.H + b * C
841  // if trans=='C': C = a * A.H * A + b * C
842  static inline
843  void zherk(const char uplo, const char trans, const int n, const int k,
844  const double alpha, const complex<double> *A, const int lda,
845  const double beta, complex<double> *C, const int ldc)
846  {
847  const char uplo_changed = change_uplo(uplo);
848  const char trans_changed = change_trans_NC(trans);
849  zherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
850  }
851 };
852 #endif // LAPACKCONNECTOR_HPP
Definition: complexmatrix.h:20
Definition: lapack_connector.h:18
Definition: matrix.h:23
Utilities to handle square matrices and related operations