Merge branch 'devel-pw2w90-pzgeqpf' into 'develop'

SCALAPACK support in routine compute_amn_with_scdm of pw2wannier90

See merge request QEF/q-e!1861
This commit is contained in:
giannozz 2022-06-05 21:05:58 +00:00
commit f3979827a8
4 changed files with 121 additions and 4 deletions

View File

@ -98,6 +98,9 @@ option(QE_LAPACK_INTERNAL
"enable internal reference LAPACK" OFF)
# ScaLAPACK support is opt-in.
option(QE_ENABLE_SCALAPACK "enable SCALAPACK execution units" OFF)
# The QRCP switch defaults to OFF and is only user-selectable while
# QE_ENABLE_SCALAPACK is ON; otherwise it is forced to OFF.
cmake_dependent_option(
    QE_ENABLE_SCALAPACK_QRCP
    "enable SCALAPACK QRCP in pw2wannier90 (requires SCALAPACK>=2.1.0 or Intel MKL>=2020)"
    OFF
    "QE_ENABLE_SCALAPACK"
    OFF)
# ELPA support is opt-in as well.
option(QE_ENABLE_ELPA "enable ELPA execution units" OFF)
option(QE_ENABLE_LIBXC
@ -484,6 +487,18 @@ if(QE_ENABLE_SCALAPACK)
INTERFACE
${SCALAPACK_LIBRARIES}
${SCALAPACK_LINKER_FLAGS})
if(QE_ENABLE_SCALAPACK_QRCP)
    # __SCALAPACK_QRCP is only defined when the selected ScaLAPACK really
    # provides pzgeqpf; otherwise configuration stops with an error.
    include(CheckFortranFunctionExists)
    include(CMakePushCheckState)
    # Save and restore the CMAKE_REQUIRED_* state around the probe instead of
    # unsetting CMAKE_REQUIRED_LIBRARIES afterwards, which would discard any
    # value that was set before this point.
    cmake_push_check_state()
    set(CMAKE_REQUIRED_LIBRARIES "${SCALAPACK_LIBRARIES}")
    check_fortran_function_exists("pzgeqpf" SCALAPACK_PZGEQPF_WORKS)
    cmake_pop_check_state()
    if(SCALAPACK_PZGEQPF_WORKS)
        message(STATUS "Found pzgeqpf, add ScaLAPACK pzgeqpf macro")
        qe_add_global_compile_definitions(__SCALAPACK_QRCP)
    else()
        message(FATAL_ERROR "QE_ENABLE_SCALAPACK_QRCP requested but the current ScaLAPACK installation doesn't contain pzgeqpf!")
    endif()
endif()
endif(QE_ENABLE_SCALAPACK)
###########################################################

View File

@ -3641,11 +3641,12 @@ SUBROUTINE compute_amn_with_scdm
USE fft_interfaces, ONLY : invfft !vv: inverse fft transform for computing the unk's on a grid
USE noncollin_module,ONLY : noncolin, npol
USE mp, ONLY : mp_bcast, mp_barrier, mp_sum
USE mp_world, ONLY : world_comm
USE mp_world, ONLY : world_comm, mpime, nproc
USE mp_pools, ONLY : intra_pool_comm
USE cell_base, ONLY : at
USE ions_base, ONLY : ntyp => nsp, tau
USE uspp_param, ONLY : upf
USE mpi, ONLY : MPI_INTEGER !NS: to be removed
IMPLICIT NONE
@ -3665,6 +3666,13 @@ SUBROUTINE compute_amn_with_scdm
CHARACTER (len=60) :: header
LOGICAL :: any_uspp, found_gamma
#if defined(__SCALAPACK_QRCP)
REAL(DP) :: tmp_rwork(2)
INTEGER :: lrwork, context, nprow, npcol, myrow, mycol, descG(9)
INTEGER :: nblocks, rem, nblocks_loc, rem_loc=0, ibl
INTEGER, ALLOCATABLE :: piv_p(:)
#endif
#if defined(__MPI)
INTEGER :: nxxs
COMPLEX(DP),ALLOCATABLE :: psic_all(:)
@ -3709,15 +3717,36 @@ SUBROUTINE compute_amn_with_scdm
info = 0
minmn = MIN(numbands,nrtot)
ALLOCATE(qr_tau(2*minmn))
#if defined(__SCALAPACK_QRCP)
! Dimensions of the process grid
nprow = 1
npcol = nproc
! Initialization of a default BLACS context and the processes grid
call blacs_get( -1, 0, context )
call blacs_gridinit( context, 'Row-major', nprow, npcol )
call blacs_gridinfo( context, nprow, npcol, myrow, mycol )
call descinit(descG, numbands, nrtot, minmn, minmn, 0, 0, context, max(1,minmn), info)
! Global blocks
nblocks = nrtot / minmn
rem = mod(nrtot, minmn)
if (rem > 0) nblocks = nblocks + 1
! Local blocks
nblocks_loc = nblocks / nproc
rem_loc = mod(nblocks, nproc)
if (mpime < rem_loc) nblocks_loc = nblocks_loc + 1
ALLOCATE(piv_p(minmn*nblocks_loc))
piv_p(:) = 0
ALLOCATE(psi_gamma(minmn*nblocks_loc,minmn))
#else
ALLOCATE(piv(nrtot))
piv(:) = 0
ALLOCATE(rwork(2*nrtot))
rwork(:) = 0.0_DP
ALLOCATE(psi_gamma(nrtot,numbands))
#endif
ALLOCATE(kpt_latt(3,iknum))
ALLOCATE(nowfc1(n_wannier,numbands))
ALLOCATE(nowfc(n_wannier,numbands))
ALLOCATE(psi_gamma(nrtot,numbands))
ALLOCATE(focc(numbands))
minmn2 = MIN(numbands,n_wannier)
maxmn2 = MAX(numbands,n_wannier)
@ -3793,10 +3822,19 @@ SUBROUTINE compute_amn_with_scdm
CALL gather_grid(dffts,psic,psic_all)
! vv: Gamma only
! vv: Build Psi_k = Unk * focc
#if defined(__SCALAPACK_QRCP)
CALL mp_bcast(psic_all,ionode_id,world_comm)
norm_psi = sqrt(real(sum(psic_all(1:nrtot)*conjg(psic_all(1:nrtot))),kind=DP))
do ibl=0,nblocks_loc-1
psi_gamma(minmn*ibl+1:minmn*(ibl+1),locibnd) = &
psic_all(minmn*(ibl*nproc+mpime)+1:minmn*(ibl*nproc+mpime+1)) * (f_gamma / norm_psi)
enddo
#else
norm_psi = sqrt(real(sum(psic_all(1:nrtot)*conjg(psic_all(1:nrtot))),kind=DP))
psic_all(1:nrtot) = psic_all(1:nrtot)/ norm_psi
psi_gamma(1:nrtot,locibnd) = psic_all(1:nrtot)
psi_gamma(1:nrtot,locibnd) = psi_gamma(1:nrtot,locibnd) * f_gamma
#endif
#else
norm_psi = sqrt(real(sum(psic(1:nrtot)*conjg(psic(1:nrtot))),kind=DP))
psic(1:nrtot) = psic(1:nrtot)/ norm_psi
@ -3806,6 +3844,33 @@ SUBROUTINE compute_amn_with_scdm
ENDDO
! vv: Perform QR factorization with pivoting on Psi_Gamma
#if defined(__SCALAPACK_QRCP)
WRITE(stdout, '(5x,A,I4,A)') "Running QRCP in parallel, using ", nproc, " cores"
call PZGEQPF( numbands, nrtot, psi_gamma, 1, 1, descG, piv_p, qr_tau, &
tmp_cwork, -1, tmp_rwork, -1, info )
lcwork = AINT(REAL(tmp_cwork(1)))
lrwork = AINT(REAL(tmp_rwork(1)))
ALLOCATE(rwork(lrwork))
ALLOCATE(cwork(lcwork))
rwork(:) = 0.0
cwork(:) = cmplx(0.0,0.0)
call PZGEQPF( numbands, nrtot, TRANSPOSE(CONJG(psi_gamma)), 1, 1, descG, piv_p, qr_tau, &
cwork, lcwork, rwork, lrwork, info )
ALLOCATE(piv(minmn))
if (ionode) piv(1:minmn) = piv_p(1:minmn)
CALL mp_bcast(piv(1:minmn),ionode_id,world_comm)
DEALLOCATE(piv_p)
#else
WRITE(stdout, '(5x, "Running QRCP in serial")')
#if defined(__SCALAPACK)
! Built with ScaLAPACK but without __SCALAPACK_QRCP: tell the user how to
! switch on the parallel QRCP. The option names printed below are the ones
! exposed by the build systems (configure and CMake).
WRITE(stdout, '(10x, A)') "Program compiled with ScaLAPACK but not using it for QRCP."
WRITE(stdout, '(10x, A)') "To enable ScaLAPACK for QRCP, use valid versions"
WRITE(stdout, '(10x, A)') "(ScaLAPACK >= 2.1.0 or MKL >= 2020) and pass"
WRITE(stdout, '(10x, A)') "'--with-scalapack-qrcp=yes' to configure"
WRITE(stdout, '(10x, A)') "(or -DQE_ENABLE_SCALAPACK_QRCP=ON to CMake)."
#endif
! vv: Preliminary call to define optimal values for lwork and cwork size
CALL ZGEQP3(numbands,nrtot,TRANSPOSE(CONJG(psi_gamma)),numbands,piv,qr_tau,tmp_cwork,-1,rwork,info)
IF(info/=0) call errore('compute_amn','Error in computing the QR factorization',1)
@ -3825,6 +3890,7 @@ SUBROUTINE compute_amn_with_scdm
! vv: Perform QR factorization with pivoting on Psi_Gamma
CALL ZGEQP3(numbands,nrtot,TRANSPOSE(CONJG(psi_gamma)),numbands,piv,qr_tau,cwork,lcwork,rwork,info)
IF(info/=0) call errore('compute_amn','Error in computing the QR factorization',1)
#endif
#endif
DEALLOCATE(cwork)
tmp_cwork(:) = (0.0_DP,0.0_DP)
@ -3970,6 +4036,12 @@ SUBROUTINE compute_amn_with_scdm
DEALLOCATE( psic_all )
#endif
#if defined(__SCALAPACK_QRCP)
! Close BLACS environment
call blacs_gridexit( context )
call blacs_exit( 1 )
#endif
IF (ionode .and. wan_mode=='standalone') CLOSE (iun_amn)
WRITE(stdout,'(/)')
WRITE(stdout,*) ' AMN calculated'

14
install/configure vendored
View File

@ -774,6 +774,7 @@ with_libxc
with_libxc_prefix
with_libxc_include
with_scalapack
with_scalapack_qrcp
with_elpa_include
with_elpa_lib
with_elpa_version
@ -1437,6 +1438,8 @@ Optional Packages:
--with-scalapack (yes|no|intel) Use scalapack if available. Set to
"intel" to use Intel MPI and blacs (default: use
openMPI)
--with-scalapack-qrcp (yes|no) Run QRCP with scalapack. Requires ScaLAPACK
>= 2.1.0 or MKL >= 2020. (default: no)
--with-elpa-include Specify full path ELPA include and modules headers
(default: no)
--with-elpa-lib Specify full path ELPA static or dynamic library
@ -7012,6 +7015,17 @@ fi
fi
done
# Optional ScaLAPACK QRCP support, requested with --with-scalapack-qrcp=yes.
# The preprocessor flag is only appended when ScaLAPACK itself was detected.
# ScaLAPACK >= 2.1.0 or MKL >= 2020 is required, but the version is not
# checked here; with an older library the QRCP results might be buggy.
if test "${with_scalapack_qrcp+set}" = set; then
  withval=$with_scalapack_qrcp
  case "$withval" in
    yes)
      if test "$have_scalapack" -eq 1; then
        try_dflags="$try_dflags -D__SCALAPACK_QRCP"
      fi
      ;;
  esac
fi
# Configuring output message
if test "$have_scalapack" -eq 1; then
scalapack_line="SCALAPACK_LIBS=$scalapack_libs"

View File

@ -89,7 +89,23 @@ fi
try_dflags="$try_dflags -D__SCALAPACK"
fi
done
# Enable QRCP with ScaLAPACK when --with-scalapack-qrcp=yes is given.
# The option is registered under its own name (scalapack-qrcp) so that it
# does not shadow --with-scalapack. The ScaLAPACK version requirement is
# stated in the help text only; it is not checked here.
AC_ARG_WITH([scalapack-qrcp],
   [AS_HELP_STRING([--with-scalapack-qrcp],
       [(yes|no) Run QRCP with scalapack. Requires ScaLAPACK >= 2.1.0 or MKL >= 2020. (default: no)])],
   [if test "$withval" = "yes" ; then
      with_scalapack_qrcp=1
   else
      with_scalapack_qrcp=0
   fi],
   [with_scalapack_qrcp=0])

# The macro is only added when ScaLAPACK itself was detected.
if test "$have_scalapack" -eq 1 && test "$with_scalapack_qrcp" -eq 1; then
   try_dflags="$try_dflags -D__SCALAPACK_QRCP"
fi
# Configuring output message
if test "$have_scalapack" -eq 1; then
scalapack_line="SCALAPACK_LIBS=$scalapack_libs"