fft_wave_wrap - a few small fixes and some doc

This commit is contained in:
fabrizio22 2022-10-19 13:11:02 +02:00
parent 13c65c960d
commit aff40f7b17
7 changed files with 56 additions and 22 deletions

View File

@ -163,7 +163,7 @@ SUBROUTINE exx_psi(c, psitot2,nnrtot,my_nbsp, my_nxyz, nbsp)
rdispls1(proc) = rdispls1(proc-1) + recvcount1(proc-1)
END DO
!
ALLOCATE ( psis(dffts%nnr*nogrp) ); psis=0.0_DP
ALLOCATE ( psis(dffts%nnr*nogrp) ); psis=0.0_DP
ALLOCATE ( psis1(dffts%nnr*nogrp) ); psis1=0.0_DP
ALLOCATE ( psis2(dffts%nnr,nproc_image/nogrp)); psis2=0.0_DP
!
@ -176,7 +176,7 @@ SUBROUTINE exx_psi(c, psitot2,nnrtot,my_nbsp, my_nxyz, nbsp)
!
DO i = 1, nbsp, 2*nogrp
!
CALL fftx_c2psi_gamma_tg( dffts, psis, c, ngw, i, nbsp )
CALL fftx_c2psi_gamma_tg( dffts, psis, c(:,i:nbsp), ngw, nbsp-i+1 )
!
CALL invfft( 'tgWave', psis, dffts )
!

View File

@ -858,7 +858,7 @@ https://www.tddft.org/programs/libxc/functionals/
Combinations of \qe\ and \libxc\ functionals are allowed in \texttt{PW}, but some attention has to be paid to their reciprocal compatibility (see section below).\\
For example, the internal exchange term of PBE together with the correlation term of PBE in \libxc\ is obtained by:
\begin{verbatim}
input_dft = `XC-001I-000I-003L-130L-000I-000I'
input_dft = `XC-001I-000I-003I-130L-000I-000I'
\end{verbatim}
which corresponds to the old:
\begin{verbatim}

View File

@ -459,13 +459,17 @@ CONTAINS
IMPLICIT NONE
!
TYPE(fft_type_descriptor), INTENT(IN) :: desc
!! fft descriptor
!! FFT descriptor
COMPLEX(DP), INTENT(OUT) :: psi(:)
!! w.f. 3D array in Fourier space
COMPLEX(DP), INTENT(IN) :: c(:,:)
!! stores the Fourier expansion coefficients of the wave function
INTEGER, INTENT(IN) :: igk(:), ngk
INTEGER, INTENT(IN) :: igk(:)
!! index of G corresponding to a given index of k+G
INTEGER, INTENT(IN) :: ngk
!! number of k+G components, i.e. the size of c(:,1)
INTEGER, OPTIONAL, INTENT(IN) :: howmany
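!! number of wave functions handled in one batched call (presumably the many-FFT case; to be confirmed)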
!
INTEGER :: nnr, i, j, ig
!

View File

@ -10,7 +10,8 @@
!
MODULE fft_wave
!
!! This module contains wrapper to FFT and inverse FFTs of w.f.
!! This module contains wrappers to FFT and inverse FFTs of the wave function,
!! which also enclose the calls to the g-vect/FFT-grid transposition routines.
!
USE kinds, ONLY: DP
USE fft_interfaces, ONLY: fwfft, invfft
@ -37,11 +38,20 @@ CONTAINS
IMPLICIT NONE
!
TYPE(fft_type_descriptor), INTENT(IN) :: dfft
COMPLEX(DP), INTENT(INOUT) :: f_in(:)
!! FFT descriptor
COMPLEX(DP) :: f_in(:)
!! input: r-space wave-function
COMPLEX(DP), INTENT(OUT) :: f_out(:,:)
!! output: g-space wave-function
INTEGER, OPTIONAL, INTENT(IN) :: igk(:)
INTEGER, OPTIONAL, INTENT(IN) :: howmany_set(2)
!! index of G corresponding to a given index of k+G
INTEGER, OPTIONAL, INTENT(IN) :: howmany_set(3)
!! gpu-enabled only (many_fft>1 case):
!! (1) group_size;
!! (2) true dimension of psi;
!! (3) howmany (if gamma) or group_size (if k).
!
! ... local variables
INTEGER :: dim1, dim2
!
dim1 = SIZE(f_in(:))
@ -51,7 +61,7 @@ CONTAINS
!
!$acc host_data use_device(f_in)
IF (PRESENT(howmany_set)) THEN
CALL fwfft( 'Wave', f_in, dfft, howmany=howmany_set(1) )
CALL fwfft( 'Wave', f_in, dfft, howmany=howmany_set(3) )
ELSE
CALL fwfft( 'Wave', f_in, dfft )
ENDIF
@ -59,7 +69,7 @@ CONTAINS
!
IF (gamma_only) THEN
IF (PRESENT(howmany_set)) THEN
CALL fftx_psi2c_gamma( dfft, f_in, f_out, howmany_set=howmany_set )
CALL fftx_psi2c_gamma( dfft, f_in, f_out, howmany_set=howmany_set(1:2) )
ELSE
IF (dim2==1) CALL fftx_psi2c_gamma( dfft, f_in, f_out(:,1:1) )
IF (dim2==2) CALL fftx_psi2c_gamma( dfft, f_in, f_out(:,1:1), &
@ -68,7 +78,7 @@ CONTAINS
ELSE
!$acc data present_or_copyin(igk)
IF (PRESENT(howmany_set)) THEN
CALL fftx_psi2c_k( dfft, f_in, f_out, igk, howmany_set )
CALL fftx_psi2c_k( dfft, f_in, f_out, igk, howmany_set(1:2) )
ELSE
CALL fftx_psi2c_k( dfft, f_in, f_out(:,1:1), igk )
ENDIF
@ -92,11 +102,20 @@ CONTAINS
IMPLICIT NONE
!
TYPE(fft_type_descriptor), INTENT(IN) :: dfft
!! FFT wave descriptor
COMPLEX(DP), INTENT(IN) :: f_in(:,:)
!! input: g-space wave-function
COMPLEX(DP) :: f_out(:)
!! output: r-space wave-function
INTEGER, OPTIONAL, INTENT(IN) :: igk(:)
INTEGER, OPTIONAL, INTENT(IN) :: howmany_set(2)
!! index of G corresponding to a given index of k+G
INTEGER, OPTIONAL, INTENT(IN) :: howmany_set(3)
!! gpu-enabled only (many_fft>1 case):
!! (1) group_size;
!! (2) true dimension of psi;
!! (3) howmany (if gamma) or group_size (if k).
!
! ... local variables
INTEGER :: npw, dim2
!
!$acc data present_or_copyin(f_in) present_or_copyout(f_out)
@ -107,7 +126,7 @@ CONTAINS
IF (gamma_only) THEN
!
IF (PRESENT(howmany_set)) THEN
CALL fftx_c2psi_gamma( dfft, f_out, f_in, howmany_set=howmany_set )
CALL fftx_c2psi_gamma( dfft, f_out, f_in, howmany_set=howmany_set(1:2) )
ELSE
IF (dim2/=2) CALL fftx_c2psi_gamma( dfft, f_out, f_in(:,1:1) )
IF (dim2==2) CALL fftx_c2psi_gamma( dfft, f_out, f_in(:,1:1), ca=f_in(:,2) )
@ -128,7 +147,7 @@ CONTAINS
!
!$acc host_data use_device( f_out )
IF (PRESENT(howmany_set)) THEN
CALL invfft( 'Wave', f_out, dfft, howmany=howmany_set(1) )
CALL invfft( 'Wave', f_out, dfft, howmany=howmany_set(3) )
ELSE
CALL invfft( 'Wave', f_out, dfft )
ENDIF
@ -151,11 +170,17 @@ CONTAINS
IMPLICIT NONE
!
TYPE(fft_type_descriptor), INTENT(IN) :: dfft
!! FFT descriptor
COMPLEX(DP) :: f_in(:,:)
!! input: wave in g-space - task group chunk
COMPLEX(DP) :: f_out(:)
!! output: wave in r-space - task group chunk
INTEGER, INTENT(IN) :: n
!! true dimension of f_in
INTEGER, OPTIONAL, INTENT(IN) :: igk(:)
!! index of G corresponding to a given index of k+G
!
! ... local variables
INTEGER :: npw, dbnd
!
!$acc data present_or_copyin(f_in,igk) present_or_copyout(f_out)
@ -195,11 +220,17 @@ CONTAINS
IMPLICIT NONE
!
TYPE(fft_type_descriptor), INTENT(IN) :: dfft
COMPLEX(DP), INTENT(INOUT) :: f_in(:)
!! FFT descriptor
COMPLEX(DP) :: f_in(:)
!! input: wave in r-space - task group chunk
COMPLEX(DP), INTENT(OUT) :: f_out(:,:)
!! output: wave in g-space - task group chunk
INTEGER, INTENT(IN) :: n
!! true dimension of f_out
INTEGER, OPTIONAL, INTENT(IN) :: igk(:)
!! index of G corresponding to a given index of k+G
!
! ... local variables
INTEGER :: dbnd
!
dbnd = SIZE(f_out(1,:))

View File

@ -18,7 +18,6 @@ SUBROUTINE sum_band()
USE cell_base, ONLY : at, bg, omega, tpiba
USE ions_base, ONLY : nat, ntyp => nsp, ityp
USE fft_base, ONLY : dfftp, dffts
USE fft_interfaces, ONLY : invfft
USE fft_rho, ONLY : rho_g2r, rho_r2g
USE fft_wave, ONLY : wave_g2r, tgwave_g2r
USE gvect, ONLY : ngm, g

View File

@ -730,7 +730,7 @@ SUBROUTINE sum_band_gpu()
ELSEIF (many_fft > 1 .AND. (.NOT. (xclib_dft_is('meta') .OR. lxdm))) THEN
!
group_size = MIN(many_fft,ibnd_end-(ibnd-1))
hm_vec(1)=group_size ; hm_vec(2)=npw
hm_vec(1)=group_size ; hm_vec(2)=npw ; hm_vec(3)=group_size
!
CALL wave_g2r( evc(:,ibnd:ibnd+group_size-1), psicd, &
dffts, igk=igk_k(:,ik), howmany_set=hm_vec )

View File

@ -51,7 +51,7 @@ SUBROUTINE vloc_psi_gamma_gpu( lda, n, m, psi_d, v_d, hpsi_d )
INTEGER :: v_siz, idx, ebnd, brange
INTEGER :: ierr, ioff
! ... Variables to handle batched FFT
INTEGER :: group_size, pack_size, remainder, howmany, hm_vec(2)
INTEGER :: group_size, pack_size, remainder, howmany, hm_vec(3)
REAL(DP):: fac
!
CALL start_clock_gpu( 'vloc_psi' )
@ -117,7 +117,7 @@ SUBROUTINE vloc_psi_gamma_gpu( lda, n, m, psi_d, v_d, hpsi_d )
ENDIF
ENDDO
!
ENDDO
ENDDO
!$acc end data
!
ELSEIF (many_fft > 1) THEN
@ -129,7 +129,7 @@ SUBROUTINE vloc_psi_gamma_gpu( lda, n, m, psi_d, v_d, hpsi_d )
pack_size = (group_size/2) ! This is FLOOR(group_size/2)
remainder = group_size - 2*pack_size
howmany = pack_size + remainder
hm_vec(1)=group_size ; hm_vec(2)=n
hm_vec(1)=group_size ; hm_vec(2)=n ; hm_vec(3)=howmany
!
CALL wave_g2r( psi(:,ibnd:ibnd+group_size-1), psic, dffts, howmany_set=hm_vec )
!
@ -260,7 +260,7 @@ SUBROUTINE vloc_psi_k_gpu( lda, n, m, psi_d, v_d, hpsi_d )
#if defined(__CUDA)
attributes(DEVICE) :: tg_v_d
!
INTEGER :: v_siz, idx, group_size, hm_vec(2)
INTEGER :: v_siz, idx, group_size, hm_vec(3)
INTEGER :: ierr, brange
!
CALL start_clock_gpu ('vloc_psi')
@ -325,7 +325,7 @@ SUBROUTINE vloc_psi_k_gpu( lda, n, m, psi_d, v_d, hpsi_d )
DO ibnd = 1, m, incr
!
group_size = MIN(many_fft,m-(ibnd-1))
hm_vec(1)=group_size ; hm_vec(2)=n
hm_vec(1)=group_size ; hm_vec(2)=n ; hm_vec(3)=group_size
ebnd = ibnd+group_size-1
!
CALL wave_g2r( psi(:,ibnd:ebnd), psic, dffts, igk=igk_k(:,current_k), &