diff --git a/PW/rdiaghg.f90 b/PW/rdiaghg.f90 index bb630a425..563d1453e 100644 --- a/PW/rdiaghg.f90 +++ b/PW/rdiaghg.f90 @@ -181,7 +181,11 @@ SUBROUTINE prdiaghg( n, h, s, ldh, e, v, desc ) USE mp_global, ONLY : root_pool, intra_pool_comm USE dspev_module, ONLY : pdspev_drv USE descriptors, ONLY : descla_siz_ , lambda_node_ , nlax_ , la_nrl_ , & - la_npc_ , la_npr_ , la_me_ , la_comm_ , la_nrlx_ + la_npc_ , la_npr_ , la_me_ , la_comm_ , la_nrlx_ , & + nlar_ , la_myc_ , la_myr_ +#if defined __SCALAPACK + USE mp_global, ONLY : ortho_cntx, me_blacs, np_ortho, me_ortho +#endif ! ! IMPLICIT NONE @@ -206,6 +210,9 @@ SUBROUTINE prdiaghg( n, h, s, ldh, e, v, desc ) REAL(DP), ALLOCATABLE :: diag(:,:), vv(:,:) REAL(DP), ALLOCATABLE :: hh(:,:) REAL(DP), ALLOCATABLE :: ss(:,:) +#ifdef __SCALAPACK + INTEGER :: desch( 16 ), info +#endif ! CALL start_clock( 'rdiaghg' ) ! @@ -232,7 +239,18 @@ SUBROUTINE prdiaghg( n, h, s, ldh, e, v, desc ) ! IF( desc( lambda_node_ ) > 0 ) THEN ! +#ifdef __SCALAPACK + CALL descinit( desch, n, n, desc( nlax_ ), desc( nlax_ ), 0, 0, ortho_cntx, SIZE( hh, 1 ) , info ) + + IF( info /= 0 ) CALL errore( ' cdiaghg ', ' descinit ', ABS( info ) ) +#endif + ! +#ifdef __SCALAPACK + CALL PDPOTRF( 'L', n, ss, 1, 1, desch, info ) + IF( info /= 0 ) CALL errore( ' rdiaghg ', ' problems computing cholesky ', ABS( info ) ) +#else CALL qe_pdpotrf( ss, nx, n, desc ) +#endif ! END IF ! @@ -244,7 +262,15 @@ SUBROUTINE prdiaghg( n, h, s, ldh, e, v, desc ) ! IF( desc( lambda_node_ ) > 0 ) THEN ! +#ifdef __SCALAPACK + CALL clear_upper_tr( ss ) + + CALL PDTRTRI( 'L', 'N', n, ss, 1, 1, desch, info ) + ! + IF( info /= 0 ) CALL errore( ' rdiaghg ', ' problems computing inverse ', ABS( info ) ) +#else CALL qe_pdtrtri ( ss, nx, n, desc ) +#endif ! END IF ! @@ -317,25 +343,33 @@ SUBROUTINE prdiaghg( n, h, s, ldh, e, v, desc ) #ifdef __SCALAPACK ! CONTAINS + + SUBROUTINE clear_upper_tr( mat ) + REAL(DP) :: mat( :, : ) + INTEGER :: i, j + IF( desc( la_myc_ ) > desc( la_myr_ ) ) mat = 0.0d0 + IF( desc( la_myc_ ) == desc( la_myr_ ) ) THEN + DO j = 1, desc( nlar_ ) + DO i = 1, j - 1 + mat( i, j ) = 0.0d0 + END DO + END DO + END IF + RETURN + END SUBROUTINE clear_upper_tr + ! SUBROUTINE scalapack_drv() - USE mp_global, ONLY : ortho_cntx, me_blacs, np_ortho, me_ortho - - INTEGER :: desch( 10 ) - REAL(DP) :: rtmp( 1 ) - INTEGER :: itmp( 1 ) + REAL(DP) :: rtmp( 4 ) + INTEGER :: itmp( 4 ) REAL(DP), ALLOCATABLE :: work(:) REAL(DP), ALLOCATABLE :: vv(:,:) INTEGER, ALLOCATABLE :: iwork(:) - INTEGER :: LWORK, LIWORK, info + INTEGER :: LWORK, LIWORK ! ALLOCATE( vv( SIZE( hh, 1 ), SIZE( hh, 2 ) ) ) - CALL descinit( desch, n, n, desc( nlax_ ), desc( nlax_ ), 0, 0, ortho_cntx, SIZE( hh, 1 ) , info ) - - IF( info /= 0 ) CALL errore( ' cdiaghg ', ' desckinit ', ABS( info ) ) - lwork = -1 liwork = 1