1 #ifndef LAPACKCONNECTOR_HPP
2 #define LAPACKCONNECTOR_HPP
9 #include "base_utility.h"
11 #include "complexmatrix.h"
12 #include "interface/blas_lapack.h"
// Allocates a new T[lda*n] buffer and fills it from `a` with
//   a_fort[i*lda+j] = a[j*n+i]   for i in [0,n), j in [0,lda).
// When is_complex<T>() && conjugate is true, each element goes through get_conj().
// NOTE(review): this listing omits several original lines (template header, braces,
// whatever sits between the two loop nests, and whatever follows the last loop) —
// presumably `else` and `return a_fort;`; confirm against the real file.
// If the buffer is returned, it was allocated with new[] and the caller must release it.
22 static inline T* transpose(
const T* a,
const int &n,
const int &lda,
bool conjugate =
false)
24 T* a_fort =
new T[lda*n];
25 if (is_complex<T>() && conjugate)
26 for (
int i = 0; i < n; i++)
27 for (
int j = 0; j < lda; j++)
28 a_fort[i*lda+j] = get_conj(a[j*n+i]);
30 for (
int i = 0; i < n; i++)
31 for (
int j = 0; j < lda; j++)
32 a_fort[i*lda+j] = a[j*n+i];
// Writes into the caller-provided a_trans with the mapping
//   a_trans[j*n+i] = a[i*lda+j]   for i in [0,n), j in [0,lda),
// i.e. the inverse index mapping of the allocating transpose() overload above.
// When is_complex<T>() && conjugate is true, each element goes through get_conj().
// No allocation or deallocation is visible here.
// NOTE(review): the line between the two loop nests is missing from this listing
// (presumably `else`) — confirm.
37 static inline void transpose(
const T* a, T* a_trans,
const int &n,
const int &lda,
bool conjugate =
false)
39 if (is_complex<T>() && conjugate)
40 for (
int i = 0; i < n; i++)
41 for (
int j = 0; j < lda; j++)
42 a_trans[j*n+i] = get_conj(a[i*lda+j]);
44 for (
int i = 0; i < n; i++)
45 for (
int j = 0; j < lda; j++)
46 a_trans[j*n+i] = a[i*lda+j];
// Copies a ComplexMatrix into a newly allocated complex<double>[lda*n] array with
//   aux[i*lda+j] = a(j,i)   for i in [0,n), j in [0,lda).
// NOTE(review): the `return` line is not visible in this listing; presumably `aux`
// is returned and the caller must delete[] it — confirm.
51 complex<double>* transpose(
const ComplexMatrix& a,
const int n,
const int lda)
53 complex<double>* aux =
new complex<double>[lda*n];
54 for (
int i = 0; i < n; ++i)
56 for (
int j = 0; j < lda; ++j)
58 aux[i*lda+j] = a(j,i);
// Copies a `matrix` into a newly allocated double[lda*n] array with
//   aux[i*lda+j] = a(j,i)   for i in [0,n), j in [0,lda).
// NOTE(review): the `return` line is not visible in this listing; presumably `aux`
// is returned and the caller must delete[] it — confirm.
65 double * transpose_matrix(
const matrix& a,
const int n,
const int lda)
67 double * aux =
new double [lda*n];
// NOTE(review): prints lda and n to stdout (and flushes via std::endl) on every
// call; this looks like leftover debugging output — confirm before removing.
68 std::cout <<
" lda=" << lda <<
" n=" << n << std::endl;
69 for (
int i=0; i<n; i++)
71 for (
int j=0; j<lda; j++)
73 aux[i*lda+j] = a(j,i);
// Copies a flat complex<double> array back into a ComplexMatrix with
//   a(j,i) = aux[i*lda+j]   for i in [0,n), j in [0,lda),
// the inverse of the ComplexMatrix -> array transpose() above.
// `aux` is only read in the lines visible here.
81 void transpose(
const complex<double>* aux,
ComplexMatrix& a,
const int n,
const int lda)
83 for (
int i = 0; i < n; ++i)
85 for (
int j = 0; j < lda; ++j)
87 a(j, i) = aux[i*lda+j];
// Copies a flat double array back into a `matrix` with
//   a(j,i) = aux[i*lda+j]   for i in [0,n), j in [0,lda),
// the inverse of the matrix -> array transpose_matrix() above.
// `aux` is only read in the lines visible here.
94 void transpose_matrix(
const double *aux,
matrix &a,
const int n,
const int lda)
96 for (
int i = 0; i < n; ++i)
98 for (
int j = 0; j < lda; ++j)
100 a(j, i) = aux[i*lda+j];
// Swaps a side flag: 'L' -> 'R', 'R' -> 'L'.
// Any other character throws invalid_argument("Side must be 'L' or 'R'").
// Matching is case-sensitive as written (only upper-case labels appear).
// NOTE(review): the switch header and braces are missing from this listing.
107 char change_side(
const char &side)
111 case 'L':
return 'R';
112 case 'R':
return 'L';
113 default:
throw invalid_argument(
"Side must be 'L' or 'R'");
// Swaps a triangle flag: 'U' -> 'L', 'L' -> 'U'.
// Any other character throws invalid_argument("uplo must be 'U' or 'L'").
// Matching is case-sensitive as written (only upper-case labels appear).
// NOTE(review): the switch header and braces are missing from this listing.
119 char change_uplo(
const char &uplo)
123 case 'U':
return 'L';
124 case 'L':
return 'U';
125 default:
throw invalid_argument(
"uplo must be 'U' or 'L'");
// Swaps a transpose flag: 'N' -> 'T', 'T' -> 'N'.
// Any other character (including 'C') throws
// invalid_argument("trans must be 'N' or 'T'").
// NOTE(review): the switch header and braces are missing from this listing.
131 char change_trans_NT(
const char &trans)
135 case 'N':
return 'T';
136 case 'T':
return 'N';
137 default:
throw invalid_argument(
"trans must be 'N' or 'T'");
// Swaps a transpose flag: 'N' -> 'C', 'C' -> 'N'.
// Any other character (including 'T') throws
// invalid_argument("trans must be 'N' or 'C'").
// NOTE(review): the switch header and braces are missing from this listing.
142 char change_trans_NC(
const char &trans)
146 case 'N':
return 'C';
147 case 'C':
return 'N';
148 default:
throw invalid_argument(
"trans must be 'N' or 'C'");
// Thin wrapper over ilaenv_: passes the scalar arguments by address and the two
// C strings through unchanged, storing the result in `nb`.
// NOTE(review): the `return` line is not visible in this listing; presumably `nb`
// is returned — confirm.
155 int ilaenv(
int ispec,
const char *name,
const char *opts,
const int n1,
const int n2,
156 const int n3,
const int n4)
158 const int nb = ilaenv_(&ispec, name, opts, &n1, &n2, &n3, &n4);
// Calls zgesv_ on a transposed copy of A: A is copied into `aux` via
// LapackConnector::transpose, zgesv_ runs on `aux`, and `aux` is copied back into A.
// B and ldb are forwarded to zgesv_ as given (no layout conversion of B is visible).
// NOTE(review): ipiv is declared `const int*`, yet it is forwarded as the pivot
// argument; the reference LAPACK ZGESV writes pivot indices there — confirm against
// the zgesv_ prototype in interface/blas_lapack.h.
// NOTE(review): no release of `aux` is visible in this listing — confirm it is freed.
163 void zgesv (
const int n,
const int nrhs,
ComplexMatrix &A,
const int lda,
164 const int *ipiv,complex<double> *B,
const int ldb,
int* info)
166 complex<double> *aux = LapackConnector::transpose(A, n, lda);
167 zgesv_(&n, &nrhs, aux, &lda, ipiv, B, &ldb, info);
168 LapackConnector::transpose(aux, A, n, lda);
// Calls zhegv_ on transposed copies of a and b: both are copied to flat arrays
// (aux with n/lda, bux with n/ldb), zhegv_ runs on the copies, and both copies are
// written back into a and b.
// NOTE(review): `info` is taken by value, so anything zhegv_ stores through &info
// stays in this function's local copy and never reaches the caller.
// NOTE(review): no release of aux/bux is visible in this listing — confirm.
173 void zhegv(
const int itype,
const char jobz,
const char uplo,
const int n,
ComplexMatrix& a,
174 const int lda,
ComplexMatrix& b,
const int ldb,
double* w,complex<double>* work,
175 int lwork,
double* rwork,
int info)
177 complex<double>* aux = LapackConnector::transpose(a, n, lda);
178 complex<double>* bux = LapackConnector::transpose(b, n, ldb);
181 zhegv_(&itype, &jobz, &uplo, &n, aux, &lda, bux, &ldb, w, work, &lwork, rwork, &info);
183 LapackConnector::transpose(aux, a, n, lda);
184 LapackConnector::transpose(bux, b, n, ldb);
// Calls dsygv_ on flat copies of a and b built in a single loop nest:
//   aux[i*lda+j] = a(j,i),  bux[i*lda+j] = b(j,i)   for i in [0,n), j in [0,lda).
// NOTE(review): the signature is truncated in this listing — `lwork` and `info`
// are used in the dsygv_ call but their declarations are not visible.
// NOTE(review): bux is sized and indexed with lda, while ldb is what gets passed to
// dsygv_ as b's leading dimension; these disagree whenever lda != ldb — confirm
// callers always pass lda == ldb, or size/index bux with ldb.
// NOTE(review): the body of the trailing loop nest (copy-back, presumably) and any
// release of aux/bux are not visible here.
223 void dsygv(
const int itype,
const char jobz,
const char uplo,
const int n,
matrix& a,
224 const int lda,
matrix& b,
const int ldb,
double* w,
double* work,
227 double* aux =
new double[lda*n];
228 double* bux =
new double[lda*n];
229 for (
int i=0; i<n; i++)
231 for (
int j=0; j<lda; j++)
233 aux[i*lda+j] = a(j,i);
234 bux[i*lda+j] = b(j,i);
238 dsygv_(&itype, &jobz, &uplo, &n, aux, &lda, bux, &ldb, w, work, &lwork, info);
239 for (
int i=0; i<n; i++)
241 for(
int j=0; j<lda; ++j)
// Calls dsyev_ directly on the caller's array `a` (no transposed copy is made);
// the only adjustment is that uplo is swapped through change_uplo() first.
// change_uplo() throws invalid_argument for anything other than 'U' or 'L'.
// `info` is a reference, so the status written by dsyev_ reaches the caller.
255 void dsyev(
const char jobz,
const char uplo,
const int n,
double *a,
256 const int lda,
double *w,
double *work,
const int lwork,
int &info)
258 const char uplo_changed = change_uplo(uplo);
259 dsyev_(&jobz, &uplo_changed, &n, a, &lda, w, work, &lwork, &info);
// Calls zheev_ on a transposed copy of `a` and copies the result back into `a`.
// NOTE(review): most of the signature is missing from this listing (uplo, n, a,
// lda, w, lwork, rwork and info are used below but not declared in the visible
// lines), so parameter types cannot be documented from here.
// NOTE(review): no release of `aux` is visible in this listing — confirm.
277 void zheev(
const char jobz,
283 complex< double >* work,
288 complex<double> *aux = LapackConnector::transpose(a, n, lda);
290 zheev_(&jobz, &uplo, &n, aux, &lda, w, work, &lwork, rwork, info);
292 LapackConnector::transpose(aux, a, n, lda);
// Calls dsytrf_ on a flat copy of `a` made by transpose_matrix(), then copies the
// result back into `a`.
// NOTE(review): `info` is taken by value, so anything dsytrf_ stores through &info
// stays local and never reaches the caller.
// NOTE(review): no release of `aux` is visible in this listing — confirm.
299 void dsytrf(
char uplo,
int n,
matrix& a,
int lda,
int *ipiv,
300 double *work,
int lwork ,
int info)
302 double * aux = LapackConnector::transpose_matrix(a, n, lda) ;
303 dsytrf_(&uplo, &n, aux, &lda, ipiv, work, &lwork, &info);
304 LapackConnector::transpose_matrix(aux, a, n, lda);
// Calls dsytri_ on a flat copy of `a` made by transpose_matrix(), then copies the
// result back into `a`. `iwork` is forwarded in the pivot-array position.
// NOTE(review): `info` is taken by value, so anything dsytri_ stores through &info
// stays local and never reaches the caller.
// NOTE(review): no release of `aux` is visible in this listing — confirm.
309 void dsytri(
char uplo,
int n,
matrix& a,
int lda,
int *iwork,
double * work,
int info)
311 double * aux = LapackConnector::transpose_matrix(a, n, lda) ;
312 dsytri_(&uplo, &n, aux, &lda, iwork, work, &info);
313 LapackConnector::transpose_matrix(aux, a, n, lda);
// Calls zhegvx_ on transposed copies of a and b, with a separate scratch array
// zux (complex<double>[n*iu]) receiving the eigenvector output; zux is then copied
// into z via transpose(zux, z, iu, n).
// NOTE(review): part of the signature is missing from this listing (n, a, lda, b
// are used below but not declared in the visible lines).
// NOTE(review): `m` is declared `const int` but passed by address in the position
// where reference LAPACK ZHEGVX returns the number of eigenvalues found; writing
// through that pointer would modify a const object — confirm against the zhegvx_
// prototype.
// NOTE(review): `info` is taken by value, so the status never reaches the caller.
// NOTE(review): a and b are not copied back in the visible lines, and no release
// of aux/bux/zux is visible — confirm.
319 void zhegvx(
const int itype,
const char jobz,
const char range,
const char uplo,
321 const int ldb,
const double vl,
const double vu,
const int il,
const int iu,
322 const double abstol,
const int m,
double* w,
ComplexMatrix& z,
const int ldz,
323 complex<double>* work,
const int lwork,
double* rwork,
int* iwork,
324 int* ifail,
int info)
327 complex<double>* aux = LapackConnector::transpose(a, n, lda);
328 complex<double>* bux = LapackConnector::transpose(b, n, ldb);
329 complex<double>* zux =
new complex<double>[n*iu];
333 zhegvx_(&itype, &jobz, &range, &uplo, &n, aux, &lda, bux, &ldb, &vl,
334 &vu, &il,&iu, &abstol, &m, w, zux, &ldz, work, &lwork, rwork, iwork, ifail, &info);
337 LapackConnector::transpose(zux, z, iu, n);
// Calls dgesvd_ on flat copies of a, u and vt made by transpose_matrix()
// (a with n/lda, u with m/ldu, vt with n/ldvt), then copies all three back.
// NOTE(review): `info` is taken by value, so anything dgesvd_ stores through &info
// stays local and never reaches the caller.
// NOTE(review): u and vt are copied in before the call even though they are passed
// in the output positions of dgesvd_; presumably harmless but redundant — confirm.
// NOTE(review): no release of aux/uux/vtux is visible in this listing — confirm.
346 void dgesvd(
const char jobu,
const char jobvt,
const int m,
const int n,
347 matrix &a,
const int lda,
double *s,
matrix &u,
const int ldu,
348 matrix &vt,
const int ldvt,
double *work,
const int lwork,
int info)
352 double *aux = LapackConnector::transpose_matrix(a, n, lda);
353 double *uux = LapackConnector::transpose_matrix(u, m, ldu);
354 double *vtux = LapackConnector::transpose_matrix(vt, n, ldvt);
356 dgesvd_(&jobu, &jobvt , &m, &n, aux, &lda, s, uux, &ldu, vtux, &ldvt, work, &lwork, &info);
358 LapackConnector::transpose_matrix(aux, a, n, lda);
359 LapackConnector::transpose_matrix(uux, u, m, ldu);
360 LapackConnector::transpose_matrix(vtux, vt, n, ldvt);
// Calls zpotrf_ on a transposed copy of `a`, then copies the result back into `a`.
// uplo is forwarded unchanged; `info` is a pointer, so the status reaches the caller.
// NOTE(review): no release of `aux` is visible in this listing — confirm.
369 void zpotrf(
char uplo,
int n,
ComplexMatrix &a,
const int lda,
int* info)
371 complex<double> *aux = LapackConnector::transpose(a, n, lda);
372 zpotrf_( &uplo, &n, aux, &lda, info);
373 LapackConnector::transpose(aux, a, n, lda);
// Calls zpotri_ on a transposed copy of `a`, then copies the result back into `a`.
// uplo is forwarded unchanged; `info` is a pointer, so the status reaches the caller.
// NOTE(review): no release of `aux` is visible in this listing — confirm.
379 void zpotri(
char uplo,
int n,
ComplexMatrix &a,
const int lda,
int* info)
381 complex<double> *aux = LapackConnector::transpose(a, n, lda);
382 zpotri_( &uplo, &n, aux, &lda, info);
383 LapackConnector::transpose(aux, a, n, lda);
// Calls zgetrf_ on a transposed copy of `a`, then copies the result back into `a`.
// NOTE(review): the copy in/out uses only n and lda; m is passed to zgetrf_ but does
// not influence the transposition — confirm this is intended for m != n.
// NOTE(review): no release of `aux` is visible in this listing — confirm.
413 void zgetrf(
int m,
int n,
ComplexMatrix &a,
const int lda,
int *ipiv,
int *info)
415 complex<double> *aux = LapackConnector::transpose(a, n, lda);
416 zgetrf_( &m, &n, aux, &lda, ipiv, info);
417 LapackConnector::transpose(aux, a, n, lda);
// Calls zgetri_ on a transposed copy of `a`, then copies the result back into `a`.
// ipiv, work and lwork are forwarded as given; `info` is a pointer.
// NOTE(review): no release of `aux` is visible in this listing — confirm.
422 void zgetri(
int n,
ComplexMatrix &a,
int lda,
int *ipiv, complex<double> * work,
int lwork,
int *info)
424 complex<double> *aux = LapackConnector::transpose(a, n, lda);
425 zgetri_( &n, aux, &lda, ipiv, work, &lwork, info);
426 LapackConnector::transpose(aux, a, n, lda);
// float overload: copies A into a new buffer via transpose(A, n, lda), runs sgetrf_
// on that copy, then writes the result back into A via transpose(a_fort, A, n, lda).
// NOTE(review): the copy uses only n and lda; m does not influence it.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
431 static inline void getrf(
const int &m,
const int &n,
float *A,
const int &lda,
int *ipiv,
int &info)
433 float *a_fort = transpose(A, n, lda);
434 sgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
435 transpose(a_fort, A, n, lda);
// double overload: copies A into a new buffer via transpose(A, n, lda), runs dgetrf_
// on that copy, then writes the result back into A via transpose(a_fort, A, n, lda).
// NOTE(review): the copy uses only n and lda; m does not influence it.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
439 static inline void getrf(
const int &m,
const int &n,
double *A,
const int &lda,
int *ipiv,
int &info)
441 double *a_fort = transpose(A, n, lda);
442 dgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
443 transpose(a_fort, A, n, lda);
// std::complex<float> overload: copies A via transpose(A, n, lda) (conjugate left at
// its default of false), runs cgetrf_ on the copy, then writes the result back into A.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
447 static inline void getrf(
const int &m,
const int &n, std::complex<float> *A,
const int &lda,
int *ipiv,
int &info)
449 std::complex<float> *a_fort = transpose(A, n, lda);
450 cgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
451 transpose(a_fort, A, n, lda);
// std::complex<double> overload: copies A via transpose(A, n, lda) (conjugate left at
// its default of false), runs zgetrf_ on the copy, then writes the result back into A.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
455 static inline void getrf(
const int &m,
const int &n, std::complex<double> *A,
const int &lda,
int *ipiv,
int &info)
457 std::complex<double> *a_fort = transpose(A, n, lda);
458 zgetrf_(&m, &n, a_fort, &lda, ipiv, &info);
459 transpose(a_fort, A, n, lda);
// float overload: forwards every argument straight to sgetrf_ with no copy or layout
// conversion of A (contrast with getrf(), which works on a transposed copy).
463 static inline void getrf_f(
const int &m,
const int &n,
float *A,
const int &lda,
int *ipiv,
int &info)
465 sgetrf_(&m, &n, A, &lda, ipiv, &info);
// double overload: forwards every argument straight to dgetrf_ with no copy or layout
// conversion of A (contrast with getrf(), which works on a transposed copy).
468 static inline void getrf_f(
const int &m,
const int &n,
double *A,
const int &lda,
int *ipiv,
int &info)
470 dgetrf_(&m, &n, A, &lda, ipiv, &info);
// std::complex<float> overload: forwards every argument straight to cgetrf_ with no
// copy or layout conversion of A.
473 static inline void getrf_f(
const int &m,
const int &n, std::complex<float> *A,
const int &lda,
int *ipiv,
int &info)
475 cgetrf_(&m, &n, A, &lda, ipiv, &info);
// std::complex<double> overload: forwards every argument straight to zgetrf_ with no
// copy or layout conversion of A.
478 static inline void getrf_f(
const int &m,
const int &n, std::complex<double> *A,
const int &lda,
int *ipiv,
int &info)
480 zgetrf_(&m, &n, A, &lda, ipiv, &info);
// float overload: copies A via transpose(A, n, lda), runs sgetri_ on the copy with
// the caller's ipiv/work/lwork, then writes the result back into A.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
483 static inline void getri(
const int &n,
float *A,
const int &lda,
int *ipiv,
float *work,
const int &lwork,
int &info)
485 float *a_fort = transpose(A, n, lda);
486 sgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
487 transpose(a_fort, A, n, lda);
// double overload: copies A via transpose(A, n, lda), runs dgetri_ on the copy with
// the caller's ipiv/work/lwork, then writes the result back into A.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
491 static inline void getri(
const int &n,
double *A,
const int &lda,
int *ipiv,
double *work,
const int &lwork,
int &info)
493 double *a_fort = transpose(A, n, lda);
494 dgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
495 transpose(a_fort, A, n, lda);
// std::complex<float> overload: copies A via transpose(A, n, lda), runs cgetri_ on
// the copy with the caller's ipiv/work/lwork, then writes the result back into A.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
499 static inline void getri(
const int &n, std::complex<float> *A,
const int &lda,
int *ipiv, std::complex<float> *work,
const int &lwork,
int &info)
501 std::complex<float> *a_fort = transpose(A, n, lda);
502 cgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
503 transpose(a_fort, A, n, lda);
// std::complex<double> overload: copies A via transpose(A, n, lda), runs zgetri_ on
// the copy with the caller's ipiv/work/lwork, then writes the result back into A.
// NOTE(review): no release of a_fort is visible in this listing — confirm.
507 static inline void getri(
const int &n, std::complex<double> *A,
const int &lda,
int *ipiv, std::complex<double> *work,
const int &lwork,
int &info)
509 std::complex<double> *a_fort = transpose(A, n, lda);
510 zgetri_(&n, a_fort, &lda, ipiv, work, &lwork, &info);
511 transpose(a_fort, A, n, lda);
// float overload: forwards every argument straight to sgetri_ with no copy or layout
// conversion of A (contrast with getri(), which works on a transposed copy).
515 static inline void getri_f(
const int &n,
float *A,
const int &lda,
int *ipiv,
float *work,
const int &lwork,
int &info)
517 sgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
// double overload: forwards every argument straight to dgetri_ with no copy or layout
// conversion of A (contrast with getri(), which works on a transposed copy).
520 static inline void getri_f(
const int &n,
double *A,
const int &lda,
int *ipiv,
double *work,
const int &lwork,
int &info)
522 dgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
// std::complex<float> overload: forwards every argument straight to cgetri_ with no
// copy or layout conversion of A.
525 static inline void getri_f(
const int &n, std::complex<float> *A,
const int &lda,
int *ipiv, std::complex<float> *work,
const int &lwork,
int &info)
527 cgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
// std::complex<double> overload: forwards every argument straight to zgetri_ with no
// copy or layout conversion of A.
530 static inline void getri_f(
const int &n, std::complex<double> *A,
const int &lda,
int *ipiv, std::complex<double> *work,
const int &lwork,
int &info)
532 zgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
// Calls dpotrf_ directly on the matrix storage a.c (no transposed copy); the only
// adjustment is that uplo is swapped through change_uplo() first.
// change_uplo() throws invalid_argument for anything other than 'U' or 'L'.
537 void dpotrf(
char uplo,
const int n,
matrix &a,
const int lda,
int *info )
539 const char uplo_changed = change_uplo(uplo);
540 dpotrf_( &uplo_changed, &n, a.c, &lda, info );
// Calls dpotri_ directly on the matrix storage a.c (no transposed copy); the only
// adjustment is that uplo is swapped through change_uplo() first.
// change_uplo() throws invalid_argument for anything other than 'U' or 'L'.
545 void dpotri(
char uplo,
const int n,
matrix &a,
const int lda,
int *info )
547 const char uplo_changed = change_uplo(uplo);
548 dpotri_( &uplo_changed, &n, a.c, &lda, info);
// float overload: forwards to saxpy_, passing the scalar arguments by address and
// the X/Y pointers unchanged.
554 void axpy(
const int n,
const float alpha,
const float *X,
const int incX,
float *Y,
const int incY)
556 saxpy_(&n, &alpha, X, &incX, Y, &incY);
// double overload: forwards to daxpy_, passing the scalar arguments by address and
// the X/Y pointers unchanged.
559 void axpy(
const int n,
const double alpha,
const double *X,
const int incX,
double *Y,
const int incY)
561 daxpy_(&n, &alpha, X, &incX, Y, &incY);
// complex<float> overload: forwards to caxpy_, passing the scalar arguments by
// address and the X/Y pointers unchanged.
564 void axpy(
const int n,
const complex<float> alpha,
const complex<float> *X,
const int incX, complex<float> *Y,
const int incY)
566 caxpy_(&n, &alpha, X, &incX, Y, &incY);
// complex<double> overload: forwards to zaxpy_, passing the scalar arguments by
// address and the X/Y pointers unchanged.
569 void axpy(
const int n,
const complex<double> alpha,
const complex<double> *X,
const int incX, complex<double> *Y,
const int incY)
571 zaxpy_(&n, &alpha, X, &incX, Y, &incY);
// float overload: forwards to sscal_, passing n, alpha and incX by address and X
// unchanged.
577 void scal(
const int n,
const float alpha,
float *X,
const int incX)
579 sscal_(&n, &alpha, X, &incX);
// double overload: forwards to dscal_, passing n, alpha and incX by address and X
// unchanged.
582 void scal(
const int n,
const double alpha,
double *X,
const int incX)
584 dscal_(&n, &alpha, X, &incX);
// complex<float> overload: forwards to cscal_, passing n, alpha and incX by address
// and X unchanged.
587 void scal(
const int n,
const complex<float> alpha, complex<float> *X,
const int incX)
589 cscal_(&n, &alpha, X, &incX);
// complex<double> overload: forwards to zscal_, passing n, alpha and incX by address
// and X unchanged.
592 void scal(
const int n,
const complex<double> alpha, complex<double> *X,
const int incX)
594 zscal_(&n, &alpha, X, &incX);
// float overload: returns the value of sdot_ called with n, incX and incY passed by
// address and the X/Y pointers unchanged.
600 float dot(
const int n,
const float *X,
const int incX,
const float *Y,
const int incY)
602 return sdot_(&n, X, &incX, Y, &incY);
// double overload: returns the value of ddot_ called with n, incX and incY passed by
// address and the X/Y pointers unchanged.
605 double dot(
const int n,
const double *X,
const int incX,
const double *Y,
const int incY)
607 return ddot_(&n, X, &incX, Y, &incY);
// std::complex<float> overload: calls cdotu_, which receives the address of a local
// `result` as its first argument (the routine name indicates the unconjugated dot
// product).
// NOTE(review): the `return` line is not visible in this listing; presumably
// `result` is returned — confirm.
612 std::complex<float> dot(
const int n,
const std::complex<float> *X,
const int incX,
const std::complex<float> *Y,
const int incY)
614 std::complex<float> result;
615 cdotu_(&result, &n, X, &incX, Y, &incY);
// std::complex<double> overload: calls zdotu_, which receives the address of a local
// `result` as its first argument (the routine name indicates the unconjugated dot
// product).
// NOTE(review): the `return` line is not visible in this listing; presumably
// `result` is returned — confirm.
620 std::complex<double> dot(
const int n,
const std::complex<double> *X,
const int incX,
const std::complex<double> *Y,
const int incY)
622 std::complex<double> result;
623 zdotu_(&result, &n, X, &incX, Y, &incY);
// float overload: calls sgemv_ with the transpose flag flipped by change_trans_NT()
// and with m and n passed in swapped order (n first, then m); `a` itself is passed
// through without copying.
// change_trans_NT() throws invalid_argument for any flag other than 'N' or 'T'.
631 void gemv(
const char transa,
const int m,
const int n,
const float alpha,
const float *a,
632 const int lda,
const float *x,
const int incx,
const float beta,
float *y,
const int incy)
634 char transa_f = change_trans_NT(transa);
635 sgemv_(&transa_f, &n, &m, &alpha, a, &lda, x, &incx, &beta, y, &incy);
// double overload: calls dgemv_ with the transpose flag flipped by change_trans_NT()
// and with m and n passed in swapped order (n first, then m); `a` itself is passed
// through without copying.
// change_trans_NT() throws invalid_argument for any flag other than 'N' or 'T'.
639 void gemv(
const char transa,
const int m,
const int n,
const double alpha,
const double *a,
640 const int lda,
const double *x,
const int incx,
const double beta,
double *y,
const int incy)
642 char transa_f = change_trans_NT(transa);
643 dgemv_(&transa_f, &n, &m, &alpha, a, &lda, x, &incx, &beta, y, &incy);
// float overload: forwards every argument to sgemv_ in the order given, with no
// flag flipping or m/n swap (contrast with gemv()).
646 inline void gemv_f(
const char &transa,
const int &m,
const int &n,
647 const float &alpha,
const float *a,
const int &lda,
648 const float *x,
const int &incx,
const float &beta,
float *y,
const int &incy)
650 sgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
// double overload: forwards every argument to dgemv_ in the order given, with no
// flag flipping or m/n swap (contrast with gemv()).
653 inline void gemv_f(
const char &transa,
const int &m,
const int &n,
654 const double &alpha,
const double *a,
const int &lda,
655 const double *x,
const int &incx,
const double &beta,
double *y,
const int &incy)
657 dgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
// std::complex<float> overload: forwards every argument to cgemv_ in the order
// given, with no flag flipping or m/n swap.
660 inline void gemv_f(
const char &transa,
const int &m,
const int &n,
661 const std::complex<float> &alpha,
const std::complex<float> *a,
const int &lda,
662 const std::complex<float> *x,
const int &incx,
const std::complex<float> &beta, std::complex<float> *y,
const int &incy)
664 cgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
// std::complex<double> overload: forwards every argument to zgemv_ in the order
// given, with no flag flipping or m/n swap.
667 inline void gemv_f(
const char &transa,
const int &m,
const int &n,
668 const std::complex<double> &alpha,
const std::complex<double> *a,
const int &lda,
669 const std::complex<double> *x,
const int &incx,
const std::complex<double> &beta, std::complex<double> *y,
const int &incy)
671 zgemv_(&transa, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
// float overload: calls sgemm_ with the two operands exchanged — transb before
// transa, n before m, and (b, ldb) before (a, lda). No copies of a or b are made in
// the visible lines.
// NOTE(review): the tail of the sgemm_ call is missing from this listing; presumably
// it passes &beta, c, &ldc — confirm.
677 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
678 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
679 const float beta,
float *c,
const int ldc)
681 sgemm_(&transb, &transa, &n, &m, &k,
682 &alpha, b, &ldb, a, &lda,
// double overload: calls dgemm_ with the two operands exchanged — transb before
// transa, n before m, and (b, ldb) before (a, lda). No copies of a or b are made in
// the visible lines.
// NOTE(review): the tail of the dgemm_ call is missing from this listing; presumably
// it passes &beta, c, &ldc — confirm.
687 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
688 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
689 const double beta,
double *c,
const int ldc)
691 dgemm_(&transb, &transa, &n, &m, &k,
692 &alpha, b, &ldb, a, &lda,
// complex<float> overload: calls cgemm_ with the two operands exchanged — transb
// before transa, n before m, and (b, ldb) before (a, lda).
// NOTE(review): the tail of the cgemm_ call is missing from this listing; presumably
// it passes &beta, c, &ldc — confirm.
697 void gemm(
const char transa,
const char transb,
const int m,
698 const int n,
const int k,
const complex<float> alpha,
699 const complex<float> *a,
const int lda,
700 const complex<float> *b,
const int ldb,
701 const complex<float> beta, complex<float> *c,
const int ldc)
703 cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda,
// complex<double> overload: calls zgemm_ with the two operands exchanged — transb
// before transa, n before m, and (b, ldb) before (a, lda).
// NOTE(review): the tail of the zgemm_ call is missing from this listing; presumably
// it passes &beta, c, &ldc — confirm.
708 void gemm(
const char transa,
const char transb,
const int m,
709 const int n,
const int k,
const complex<double> alpha,
710 const complex<double> *a,
const int lda,
711 const complex<double> *b,
const int ldb,
712 const complex<double> beta, complex<double> *c,
const int ldc)
714 zgemm_(&transb, &transa, &n, &m, &k,
715 &alpha, b, &ldb, a, &lda,
// float overload: forwards to sgemm_ with the operands in the order given (transa,
// transb, m, n, k, a, b) — no operand exchange, unlike gemm().
// NOTE(review): the tail of the sgemm_ call is missing from this listing; presumably
// it passes &beta, c, &ldc — confirm.
720 void gemm_f(
const char transa,
const char transb,
721 const int m,
const int n,
const int k,
722 const float alpha,
const float *a,
723 const int lda,
const float *b,
const int ldb,
724 const float beta,
float *c,
const int ldc)
726 sgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb,
// complex<float> overload: forwards to cgemm_ with the operands in the order given —
// no operand exchange, unlike gemm().
// NOTE(review): the tail of the cgemm_ call (after `c,`) is missing from this
// listing; presumably &ldc — confirm.
730 static inline void gemm_f(
const char transa,
const char transb,
731 const int m,
const int n,
const int k,
732 const complex<float> alpha,
733 const complex<float> *a,
const int lda,
734 const complex<float> *b,
const int ldb,
735 const complex<float> beta,
736 complex<float> *c,
const int ldc)
738 cgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c,
// double overload: forwards to dgemm_ with the operands in the order given — no
// operand exchange, unlike gemm().
// NOTE(review): the tail of the dgemm_ call is missing from this listing; presumably
// it passes &beta, c, &ldc — confirm.
742 static inline void gemm_f(
const char transa,
const char transb,
743 const int m,
const int n,
const int k,
744 const double alpha,
const double *a,
745 const int lda,
const double *b,
const int ldb,
746 const double beta,
double *c,
const int ldc)
748 dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb,
// complex<double> overload: forwards to zgemm_ with the operands in the order given —
// no operand exchange, unlike gemm().
// NOTE(review): the tail of the zgemm_ call (after `c,`) is missing from this
// listing; presumably &ldc — confirm.
752 static inline void gemm_f(
const char transa,
const char transb,
753 const int m,
const int n,
const int k,
754 const complex<double> alpha,
755 const complex<double> *a,
const int lda,
756 const complex<double> *b,
const int ldb,
757 const complex<double> beta,
758 complex<double> *c,
const int ldc)
760 zgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c,
// std::complex<float> overload: copies `a` into a new buffer via
// LapackConnector::transpose(a, n, lda), runs cheev_ on the copy with uplo passed
// unchanged, then writes the copy back into `a` via transpose(aux, a, n, lda).
// NOTE(review): no release of `aux` is visible in this listing — confirm.
766 void heev(
const char &jobz,
const char &uplo,
const int &n,
767 std::complex<float> *a,
const int &lda,
float *w,
768 std::complex<float> *work,
const int &lwork,
float *rwork,
int &info)
770 complex<float> *aux = LapackConnector::transpose(a, n, lda);
772 cheev_(&jobz, &uplo, &n, aux, &lda, w, work, &lwork, rwork, &info);
774 LapackConnector::transpose(aux, a, n, lda);
// std::complex<double> overload: copies `a` into a new buffer via
// LapackConnector::transpose(a, n, lda), runs zheev_ on the copy with uplo passed
// unchanged, then writes the copy back into `a` via transpose(aux, a, n, lda).
// NOTE(review): no release of `aux` is visible in this listing — confirm.
780 void heev(
const char &jobz,
const char &uplo,
const int &n,
781 std::complex<double> *a,
const int &lda,
double *w,
782 std::complex<double> *work,
const int &lwork,
double *rwork,
int &info)
784 complex<double> *aux = LapackConnector::transpose(a, n, lda);
786 zheev_(&jobz, &uplo, &n, aux, &lda, w, work, &lwork, rwork, &info);
788 LapackConnector::transpose(aux, a, n, lda);
// std::complex<float> overload: forwards every argument straight to cheev_ with no
// copy or layout conversion of `a` (contrast with heev()).
795 void heev_f(
const char &jobz,
const char &uplo,
const int &n,
796 std::complex<float> *a,
const int &lda,
float *w,
797 std::complex<float> *work,
const int &lwork,
float *rwork,
int &info)
799 cheev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, rwork, &info);
// std::complex<double> overload: forwards every argument straight to zheev_ with no
// copy or layout conversion of `a` (contrast with heev()).
803 void heev_f(
const char &jobz,
const char &uplo,
const int &n,
804 std::complex<double> *a,
const int &lda,
double *w,
805 std::complex<double> *work,
const int &lwork,
double *rwork,
int &info)
807 zheev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, rwork, &info);
// double overload: forwards to dcopy_, passing n, incx and incy by address and the
// a/b pointers unchanged.
// NOTE(review): n is `long` here while the other BLAS wrappers in this file take
// `int` counts; confirm the dcopy_ prototype really expects a pointer to long.
829 void copy(
const long n,
const double *a,
const int incx,
double *b,
const int incy)
831 dcopy_(&n, a, &incx, b, &incy);
// complex<double> overload: forwards to zcopy_, passing n, incx and incy by address
// and the a/b pointers unchanged.
// NOTE(review): n is `long` here while the other BLAS wrappers in this file take
// `int` counts; confirm the zcopy_ prototype really expects a pointer to long.
834 void copy(
const long n,
const complex<double> *a,
const int incx, complex<double> *b,
const int incy)
836 zcopy_(&n, a, &incx, b, &incy);
// Calls zherk_ directly on A and C (no copies), after swapping uplo through
// change_uplo() and trans through change_trans_NC().
// Both helpers throw invalid_argument on unsupported flags: uplo must be 'U' or 'L',
// trans must be 'N' or 'C'.
843 void zherk(
const char uplo,
const char trans,
const int n,
const int k,
844 const double alpha,
const complex<double> *A,
const int lda,
845 const double beta, complex<double> *C,
const int ldc)
847 const char uplo_changed = change_uplo(uplo);
848 const char trans_changed = change_trans_NC(trans);
849 zherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
Definition: complexmatrix.h:20
Definition: lapack_connector.h:18
utilities to handle square matrices and related operations