From 5ff7a010763969bf74b60894aa2ef54df514eb90 Mon Sep 17 00:00:00 2001 From: fabrizio22 Date: Tue, 30 Aug 2022 18:08:23 +0200 Subject: [PATCH] fft_wave_wrap - g2r wrapper & meta calls --- FFTXlib/src/fft_helper_subroutines.f90 | 29 +++++++++-- Modules/fft_wave.f90 | 46 ++++++++++++++++- PW/src/h_psi_meta.f90 | 69 +++++++++++--------------- 3 files changed, 98 insertions(+), 46 deletions(-) diff --git a/FFTXlib/src/fft_helper_subroutines.f90 b/FFTXlib/src/fft_helper_subroutines.f90 index 57c5486d3..7688fdc79 100644 --- a/FFTXlib/src/fft_helper_subroutines.f90 +++ b/FFTXlib/src/fft_helper_subroutines.f90 @@ -32,7 +32,7 @@ MODULE fft_helper_subroutines tg_get_group_nr3 ! ... Used only in CP PUBLIC :: fftx_add_threed2oned_gamma, fftx_psi2c_gamma, c2psi_gamma, & - fftx_add_field, c2psi_gamma_tg, c2psi_k, c2psi_k_tg + fftx_add_field, c2psi_gamma_tg, c2psi_k, c2psi_k_tg, fftx_psi2c_k PUBLIC :: fft_dist_info ! ... Used only in CP+EXX PUBLIC :: fftx_tgcomm @@ -597,9 +597,9 @@ CONTAINS IMPLICIT NONE ! TYPE(fft_type_descriptor), INTENT(in) :: desc - complex(DP), INTENT(OUT) :: vout1(:) - complex(DP), OPTIONAL, INTENT(OUT) :: vout2(:) - complex(DP), INTENT(IN) :: vin(:) + COMPLEX(DP), INTENT(OUT) :: vout1(:) + COMPLEX(DP), OPTIONAL, INTENT(OUT) :: vout2(:) + COMPLEX(DP), INTENT(IN) :: vin(:) COMPLEX(DP) :: fp, fm INTEGER :: ig ! @@ -651,6 +651,27 @@ CONTAINS END IF END SUBROUTINE fftx_psi2c_gamma_gpu ! + !------------------------------------------------------------ + SUBROUTINE fftx_psi2c_k( desc, vin, vout, igk ) + !--------------------------------------------------------- + ! + USE fft_types, ONLY : fft_type_descriptor + ! + TYPE(fft_type_descriptor), INTENT(IN) :: desc + COMPLEX(DP), INTENT(IN) :: vin(:) + COMPLEX(DP), INTENT(OUT) :: vout(:) + INTEGER, INTENT(IN) :: igk(:) + ! + INTEGER :: ig + ! + DO ig = 1, desc%ngw + vout(ig) = vin(desc%nl(igk(ig))) + ENDDO + ! + RETURN + ! + END SUBROUTINE fftx_psi2c_k + ! !-------------------------------------------------------------------- SUBROUTINE c2psi_gamma_tg( desc, psis, c_bgrp, i, nbsp_bgrp ) !----------------------------------------------------------------- diff --git a/Modules/fft_wave.f90 b/Modules/fft_wave.f90 index 7206e449f..1cd413e11 100644 --- a/Modules/fft_wave.f90 +++ b/Modules/fft_wave.f90 @@ -13,7 +13,7 @@ MODULE fft_wave !! This module contains wrapper to FFT and inverse FFTs of w.f. ! USE kinds, ONLY: DP - USE fft_interfaces, ONLY: invfft + USE fft_interfaces, ONLY: fwfft, invfft USE fft_types, ONLY: fft_type_descriptor USE control_flags, ONLY: gamma_only ! @@ -21,9 +21,50 @@ MODULE fft_wave ! PRIVATE ! - PUBLIC :: wave_g2r, tgwave_g2r + PUBLIC :: wave_r2g, wave_g2r, tgwave_g2r ! CONTAINS + ! + ! + !---------------------------------------------------------------------- + SUBROUTINE wave_r2g( f_in, f_out, dfft, igk ) + !-------------------------------------------------------------------- + !! Wave function FFT from R to G-space. + ! + USE fft_helper_subroutines, ONLY: fftx_psi2c_gamma, fftx_psi2c_k + ! + IMPLICIT NONE + ! + TYPE(fft_type_descriptor), INTENT(IN) :: dfft + COMPLEX(DP), INTENT(IN) :: f_in(:) + COMPLEX(DP), INTENT(OUT) :: f_out(:,:) + INTEGER, OPTIONAL, INTENT(IN) :: igk(:) + ! + ! ... local variables + ! + COMPLEX(DP), ALLOCATABLE :: psic(:) + INTEGER :: dim2, nrxxs + ! + nrxxs = SIZE(f_in) + dim2 = SIZE(f_out(1,:)) + ! + ALLOCATE( psic(nrxxs) ) + psic = f_in + ! + CALL fwfft( 'Wave', psic, dfft ) + ! + IF (gamma_only) THEN + IF (dim2==1) CALL fftx_psi2c_gamma( dfft, psic, f_out(:,1) ) + IF (dim2==2) CALL fftx_psi2c_gamma( dfft, psic, f_out(:,1), f_out(:,2) ) + ELSE + CALL fftx_psi2c_k( dfft, psic, f_out(:,1), igk ) + ENDIF + ! + DEALLOCATE( psic ) + ! + RETURN + ! + END SUBROUTINE wave_r2g ! ! !---------------------------------------------------------------------- @@ -80,6 +121,7 @@ CONTAINS ! END SUBROUTINE wave_g2r ! + ! !---------------------------------------------------------------------- SUBROUTINE tgwave_g2r( f_in, f_out, dfft, ibnd, ibnd_end, igk ) !-------------------------------------------------------------------- diff --git a/PW/src/h_psi_meta.f90 b/PW/src/h_psi_meta.f90 index e1b90365b..1681ab7d2 100644 --- a/PW/src/h_psi_meta.f90 +++ b/PW/src/h_psi_meta.f90 @@ -22,7 +22,7 @@ SUBROUTINE h_psi_meta( ldap, np, mp, psip, hpsi ) USE control_flags, ONLY : gamma_only USE wavefunctions, ONLY : psic USE fft_base, ONLY : dffts - USE fft_wave, ONLY : wave_g2r + USE fft_wave, ONLY : wave_r2g, wave_g2r USE fft_interfaces, ONLY : fwfft ! IMPLICIT NONE @@ -40,58 +40,49 @@ SUBROUTINE h_psi_meta( ldap, np, mp, psip, hpsi ) ! ! ... local variables ! - REAL(DP), ALLOCATABLE :: kplusg(:) - INTEGER :: im, j, nrxxs - - INTEGER :: i, ebnd, brange - REAL(DP) :: kplusgi - - COMPLEX(DP), ALLOCATABLE :: kplusg_evc(:,:) + COMPLEX(DP), ALLOCATABLE :: psi_g(:,:) + INTEGER :: im, i, j, nrxxs, ebnd, brange + REAL(DP) :: kplusgi, fac COMPLEX(DP), PARAMETER :: ci=(0.d0,1.d0) ! CALL start_clock( 'h_psi_meta' ) ! nrxxs = dffts%nnr - ALLOCATE( kplusg(np) ) - - ALLOCATE( kplusg_evc(np,2) ) - + ! + ALLOCATE( psi_g(np,2) ) ! IF (gamma_only) THEN ! - ! ... gamma algorithm + ! ... Gamma algorithm ! DO im = 1, mp, 2 + ! + fac = 1.d0 + IF ( im < mp ) fac = 0.5d0 + ! DO j = 1, 3 ! DO i = 1, np kplusgi = (xk(j,current_k)+g(j,i)) * tpiba - kplusg_evc(i,1) = CMPLX(0.D0,kplusgi) * psip(i,im) - IF ( im < mp ) kplusg_evc(i,2) = CMPLX(0.d0,kplusgi) * psip(i,im+1) + psi_g(i,1) = CMPLX(0.D0,kplusgi) * psip(i,im) + IF ( im < mp ) psi_g(i,2) = CMPLX(0.d0,kplusgi) * psip(i,im+1) ENDDO ! ebnd = im IF ( im < mp ) ebnd = ebnd + 1 brange = ebnd-im+1 ! - CALL wave_g2r( kplusg_evc(1:np,1:brange), psic, dffts ) + CALL wave_g2r( psi_g(1:np,1:brange), psic, dffts ) ! psic(1:nrxxs) = kedtau(1:nrxxs,current_spin) * psic(1:nrxxs) ! - CALL fwfft( 'Wave', psic, dffts ) + CALL wave_r2g( psic, psi_g(:,1:brange), dffts ) ! - - kplusg (1:np) = (xk(j,current_k)+g(j,1:np)) * tpiba - - IF ( im < mp ) THEN - hpsi(1:np,im) = hpsi(1:np,im) - ci * kplusg(1:np) * 0.5d0 * & - ( psic(dffts%nl(1:np)) + CONJG(psic(dffts%nlm(1:np))) ) - hpsi(1:np,im+1) = hpsi(1:np,im+1) - kplusg(1:np) * 0.5d0 * & - ( psic(dffts%nl(1:np)) - CONJG(psic(dffts%nlm(1:np))) ) - ELSE - hpsi(1:np,im) = hpsi(1:np,im) - ci * kplusg(1:np) * & - psic(dffts%nl(1:np)) - ENDIF + DO i = 1, np + kplusgi = (xk(j,current_k)+g(j,i)) * tpiba + hpsi(i,im) = hpsi(i,im) - ci * kplusgi * fac * psi_g(i,1) + IF ( im < mp ) hpsi(i,im+1) = hpsi(i,im+1) - ci * kplusgi * fac * psi_g(i,2) + ENDDO ! ENDDO ENDDO @@ -105,28 +96,26 @@ SUBROUTINE h_psi_meta( ldap, np, mp, psip, hpsi ) ! DO i = 1, np kplusgi = (xk(j,current_k)+g(j,igk_k(i,current_k))) * tpiba - kplusg_evc(i,1) = CMPLX(0.D0,kplusgi,kind=DP) * psip(i,im) + psi_g(i,1) = CMPLX(0.D0,kplusgi,kind=DP) * psip(i,im) ENDDO ! - CALL wave_g2r( kplusg_evc(1:np,1:1), psic, dffts, igk=igk_k(:,current_k) ) + CALL wave_g2r( psi_g(:,1:1), psic, dffts, igk=igk_k(:,current_k) ) ! psic(1:nrxxs) = kedtau(1:nrxxs,current_spin) * psic(1:nrxxs) ! - - kplusg (1:np) = (xk(j,current_k)+g(j,igk_k(1:np,current_k)))*tpiba - - CALL fwfft( 'Wave', psic, dffts ) + CALL wave_r2g( psic, psi_g(:,1:1), dffts, igk=igk_k(:,current_k) ) + ! + DO i = 1, np + kplusgi = (xk(j,current_k)+g(j,i)) * tpiba + hpsi(i,im) = hpsi(i,im) - CMPLX(0.D0,kplusgi,KIND=DP) * psi_g(i,1) + ENDDO ! - hpsi(1:np,im) = hpsi(1:np,im) - CMPLX(0d0, kplusg(1:np), KIND=DP) & - * psic(dffts%nl(igk_k(1:np,current_k))) ENDDO ENDDO ! ENDIF ! - DEALLOCATE( kplusg_evc ) - - DEALLOCATE( kplusg ) + DEALLOCATE( psi_g ) ! CALL stop_clock( 'h_psi_meta' ) !