fft_wave_wrap - r2g calls in vloc_psi gamma & gamma tg fix

This commit is contained in:
fabrizio22 2022-09-08 10:05:03 +02:00
parent b83bbb79bf
commit 16fe41f61d
3 changed files with 68 additions and 76 deletions

View File

@ -684,15 +684,18 @@ CONTAINS
!
TYPE(fft_type_descriptor), INTENT(IN) :: desc
COMPLEX(DP), INTENT(OUT) :: psis(:)
COMPLEX(DP), INTENT(INOUT) :: c_bgrp(:,:)
COMPLEX(DP), INTENT(IN) :: c_bgrp(:,:)
INTEGER, INTENT(IN) :: i, nbsp_bgrp
!
INTEGER :: eig_offset, eig_index, right_nnr
INTEGER :: eig_offset, eig_index, right_nnr, ig, ib, ieg
COMPLEX(DP), PARAMETER :: ci=(0.0d0,1.0d0)
!
! the i-th column of c_bgrp corresponds to the i-th state (in this band group)
! ... the i-th column of c_bgrp corresponds to the i-th state (in this band group)
!
! The outer loop goes through i : i + 2*NOGRP to cover
! 2*NOGRP eigenstates at each iteration
! ... The outer loop goes through i : i + 2*NOGRP to cover
! ... 2*NOGRP eigenstates at each iteration
!
CALL alloc_nl_pntrs( desc )
!
right_nnr = desc%nnr
!
@ -708,32 +711,35 @@ CONTAINS
#if !defined(_OPENACC)
!$omp task default(none) &
!$omp firstprivate( eig_index, i, nbsp_bgrp, right_nnr ) &
!$omp private( eig_offset ) &
!$omp shared( c_bgrp, desc, psis )
!$omp private( eig_offset, ib, ieg, ig ) &
!$omp shared( c_bgrp, desc, psis, nl_d, nlm_d )
#endif
!
! here we pack 2*nogrp electronic states in the psis array
! note that if nogrp == nproc_bgrp each proc performs a full 3D
! fft and the scatter phase is local (without communication)
! ... here we pack 2*nogrp electronic states in the psis array
! ... note that if nogrp == nproc_bgrp each proc performs a full 3D
! ... fft and the scatter phase is local (without communication)
!
! important: if n is odd => c(*,n+1)=0.
!
IF ( (eig_index+i-1)==nbsp_bgrp ) THEN
!$acc kernels
c_bgrp(:,eig_index+i) = (0._DP,0._DP)
!$acc end kernels
ENDIF
! ... important: if n is odd => c(*,n+1)=0.
!
eig_offset = (eig_index-1)/2
!
IF ( (i+eig_index-1) <= nbsp_bgrp ) THEN
!
! The eig_index loop is executed only ONCE when NOGRP=1.
!
CALL c2psi_gamma( desc, psis(eig_offset*right_nnr+1: &
eig_offset*right_nnr+ right_nnr), &
c_bgrp(:,i+eig_index-1), c_bgrp(:,i+eig_index) )
!
ib = eig_offset*right_nnr
ieg = i+eig_index-1
!
! ... The eig_index loop is executed only ONCE when NOGRP=1.
IF ( ieg < nbsp_bgrp ) THEN
!$acc parallel loop
DO ig = 1, desc%ngw
psis(ib+nlm_d(ig)) = CONJG(c_bgrp(ig,ieg)) + ci * CONJG(c_bgrp(ig,ieg+1))
psis(ib+nl_d(ig)) = c_bgrp(ig,ieg) + ci * c_bgrp(ig,ieg+1)
ENDDO
ELSEIF ( ieg == nbsp_bgrp ) THEN
! ... important: if n is odd => c(*,n+1)=0.
!$acc parallel loop
DO ig = 1, desc%ngw
psis(ib+nlm_d(ig)) = CONJG(c_bgrp(ig,ieg))
psis(ib+nl_d(ig)) = c_bgrp(ig,ieg)
ENDDO
ENDIF
#if !defined(_OPENACC)
!$omp end task
@ -746,6 +752,8 @@ CONTAINS
!$omp end single
!$omp end parallel
#endif
!
CALL dealloc_nl_pntrs( desc )
!
RETURN
!

View File

@ -21,7 +21,7 @@ MODULE fft_wave
!
PRIVATE
!
PUBLIC :: wave_r2g, wave_g2r, tgwave_g2r, tgwave_r2g
PUBLIC :: wave_r2g, wave_g2r, tgwave_r2g, tgwave_g2r
!
CONTAINS
!
@ -75,7 +75,7 @@ CONTAINS
IMPLICIT NONE
!
TYPE(fft_type_descriptor), INTENT(IN) :: dfft
COMPLEX(DP) :: f_in(:,:)
COMPLEX(DP), INTENT(IN) :: f_in(:,:)
COMPLEX(DP) :: f_out(:)
INTEGER, OPTIONAL, INTENT(IN) :: igk(:)
INTEGER, OPTIONAL, INTENT(IN) :: howmany_set(3)

View File

@ -15,11 +15,10 @@ SUBROUTINE vloc_psi_gamma( lda, n, m, psi, v, hpsi )
USE kinds, ONLY : DP
USE mp_bands, ONLY : me_bgrp
USE fft_base, ONLY : dffts
USE fft_wave, ONLY : wave_g2r, tgwave_g2r
USE fft_interfaces, ONLY : fwfft
USE fft_wave
USE wavefunctions, ONLY : psic
USE fft_helper_subroutines, ONLY : fftx_ntgrp, tg_get_nnr, &
tg_get_group_nr3, tg_get_recip_inc
USE fft_helper_subroutines, ONLY : fftx_ntgrp, tg_get_group_nr3, &
tg_get_recip_inc
!
IMPLICIT NONE
!
@ -38,15 +37,16 @@ SUBROUTINE vloc_psi_gamma( lda, n, m, psi, v, hpsi )
!
! ... local variables
!
INTEGER :: ibnd, j, incr
INTEGER :: right_nnr, right_nr3, right_inc
INTEGER :: ibnd, j, incr, right_nr3, right_inc
COMPLEX(DP) :: fp, fm
COMPLEX(DP), ALLOCATABLE :: psi2(:,:)
!
!Variables for task groups
LOGICAL :: use_tg
INTEGER :: v_siz, idx, ioff, ebnd, brange
REAL(DP) :: fac
REAL(DP), ALLOCATABLE :: tg_v(:)
COMPLEX(DP), ALLOCATABLE :: tg_psic(:)
INTEGER :: v_siz, idx, ioff, ebnd
!
CALL start_clock( 'vloc_psi' )
incr = 2
@ -54,7 +54,6 @@ SUBROUTINE vloc_psi_gamma( lda, n, m, psi, v, hpsi )
use_tg = dffts%has_task_groups
!
IF( use_tg ) THEN
!
CALL start_clock( 'vloc_psi:tg_gather' )
v_siz = dffts%nnr_tg
!
@ -66,6 +65,9 @@ SUBROUTINE vloc_psi_gamma( lda, n, m, psi, v, hpsi )
!
incr = 2 * fftx_ntgrp(dffts)
!
ALLOCATE( psi2(n,m) )
ELSE
ALLOCATE( psi2(n,2) )
ENDIF
!
! ... the local potential V_Loc psi. First bring psi to real space
@ -75,8 +77,6 @@ SUBROUTINE vloc_psi_gamma( lda, n, m, psi, v, hpsi )
! ... FFT to real space
!
IF ( use_tg ) THEN
!
CALL tg_get_nnr( dffts, right_nnr )
!
CALL tgwave_g2r( psi(1:n,:), tg_psic, dffts, ibnd, m )
!
@ -96,7 +96,7 @@ SUBROUTINE vloc_psi_gamma( lda, n, m, psi, v, hpsi )
CALL tg_get_group_nr3( dffts, right_nr3 )
!
DO j = 1, dffts%nr1x * dffts%nr2x * right_nr3
tg_psic (j) = tg_psic (j) * tg_v(j)
tg_psic(j) = tg_psic(j) * tg_v(j)
ENDDO
!
ELSE
@ -112,68 +112,52 @@ SUBROUTINE vloc_psi_gamma( lda, n, m, psi, v, hpsi )
!
IF( use_tg ) THEN
!
CALL fwfft( 'tgWave', tg_psic, dffts )
fac=1.d0
IF ( idx+ibnd-1<m ) fac=0.5d0
!
ioff = 0
!
CALL tg_get_recip_inc( dffts, right_inc )
CALL tgwave_r2g( tg_psic, psi2, dffts, n, ibnd, m )
!
DO idx = 1, 2*fftx_ntgrp(dffts), 2
!
IF( idx+ibnd-1<m ) THEN
IF ( idx+ibnd-1<m ) THEN
DO j = 1, n
fp= ( tg_psic( dffts%nl(j) + ioff ) + &
tg_psic( dffts%nlm(j) + ioff ) ) * 0.5d0
fm= ( tg_psic( dffts%nl(j) + ioff ) - &
tg_psic( dffts%nlm(j) + ioff ) ) * 0.5d0
hpsi(j,ibnd+idx-1) = hpsi(j,ibnd+idx-1) + &
cmplx( dble(fp), aimag(fm),kind=DP)
hpsi(j,ibnd+idx) = hpsi(j,ibnd+idx) + &
cmplx(aimag(fp),-dble(fm),kind=DP)
hpsi(j,ibnd+idx-1) = hpsi(j,ibnd+idx-1) + fac * psi2(j,ibnd+idx-1)
hpsi(j,ibnd+idx) = hpsi(j,ibnd+idx) + fac * psi2(j,ibnd+idx)
ENDDO
ELSEIF( idx + ibnd - 1 == m ) THEN
ELSEIF ( idx+ibnd-1==m ) THEN
DO j = 1, n
hpsi (j, ibnd+idx-1) = hpsi (j, ibnd+idx-1) + &
tg_psic( dffts%nl(j) + ioff )
hpsi(j,ibnd+idx-1) = hpsi(j,ibnd+idx-1) + fac * psi2(j,ibnd+idx-1)
ENDDO
ENDIF
!
ioff = ioff + right_inc
!
ENDDO
!
ELSE
!
CALL fwfft( 'Wave', psic, dffts )
!
IF (ibnd < m) THEN
! two ffts at the same time
DO j = 1, n
fp = (psic (dffts%nl(j)) + psic (dffts%nlm(j)))*0.5d0
fm = (psic (dffts%nl(j)) - psic (dffts%nlm(j)))*0.5d0
hpsi (j, ibnd) = hpsi (j, ibnd) + &
cmplx( dble(fp), aimag(fm),kind=DP)
hpsi (j, ibnd+1) = hpsi (j, ibnd+1) + &
cmplx(aimag(fp),- dble(fm),kind=DP)
ENDDO
ELSE
DO j = 1, n
hpsi (j, ibnd) = hpsi (j, ibnd) + psic (dffts%nl(j))
ENDDO
brange=1 ; fac=1.d0
IF ( ibnd<m ) THEN
brange=2 ; fac=0.5d0
ENDIF
!
CALL wave_r2g( psic, psi2(:,1:brange), dffts )
!
DO j = 1, n
hpsi(j,ibnd) = hpsi(j,ibnd) + fac*psi2(j,1)
IF ( ibnd<m ) hpsi(j,ibnd+1) = hpsi(j,ibnd+1) + fac*psi2(j,2)
ENDDO
!
ENDIF
!
ENDDO
!
IF( use_tg ) THEN
!
DEALLOCATE( tg_psic )
DEALLOCATE( tg_v )
!
ENDIF
DEALLOCATE( psi2 )
!
CALL stop_clock ('vloc_psi')
!
RETURN
!
END SUBROUTINE vloc_psi_gamma
!
!-----------------------------------------------------------------------