Added an experimental version with double-buffering for the pack_group_sticks. The target is to overlap the communication therein with the computation of the FFTs. It can be activated adding -D__DOUBLE_BUFFER

git-svn-id: http://qeforge.qe-forge.org/svn/q-e/trunk/espresso@12233 c92efa57-630b-4861-b058-cf58834340f0
This commit is contained in:
faffinito 2016-03-21 11:20:17 +00:00
parent 6bb6019cd8
commit 9e9ded0412
2 changed files with 101 additions and 9 deletions

View File

@ -382,6 +382,58 @@ SUBROUTINE bw_tg_cft3_xy( f, dfft )
!
END SUBROUTINE bw_tg_cft3_xy
#ifdef __DOUBLE_BUFFER
SUBROUTINE pack_group_sticks_i( f, yf, dfft, req)
USE fft_types, ONLY : fft_dlay_descriptor
IMPLICIT NONE
#if defined(__MPI)
INCLUDE 'mpif.h'
#endif
COMPLEX(DP), INTENT(in) :: f( : ) ! array containing all bands, and gvecs distributed across processors
COMPLEX(DP), INTENT(out) :: yf( : ) ! array containing bands collected into task groups
TYPE (fft_dlay_descriptor), INTENT(in) :: dfft
INTEGER :: ierr,req
!
IF( dfft%tg_rdsp(dfft%nogrp) + dfft%tg_rcv(dfft%nogrp) > size( yf ) ) THEN
CALL fftx_error__( 'pack_group_sticks' , ' inconsistent size ', 1 )
ENDIF
IF( dfft%tg_psdsp(dfft%nogrp) + dfft%tg_snd(dfft%nogrp) > size( f ) ) THEN
CALL fftx_error__( 'pack_group_sticks', ' inconsistent size ', 2 )
ENDIF
CALL start_clock( 'IALLTOALL' )
!
! Collect all the sticks of the different states,
! in "yf" processors will have all the sticks of the OGRP
#if defined(__MPI)
CALL MPI_IALLTOALLV( f(1), dfft%tg_snd, dfft%tg_psdsp, MPI_DOUBLE_COMPLEX, yf(1), dfft%tg_rcv, &
& dfft%tg_rdsp, MPI_DOUBLE_COMPLEX, dfft%ogrp_comm, req, ierr)
IF( ierr /= 0 ) THEN
CALL fftx_error__( 'pack_group_sticks_i', ' alltoall error 1 ', abs(ierr) )
ENDIF
#else
IF( dfft%tg_rcv(dfft%nogrp) /= dfft%tg_snd(dfft%nogrp) ) THEN
CALL fftx_error__( 'pack_group_sticks', ' inconsistent size ', 3 )
ENDIF
yf( 1 : dfft%tg_rcv(dfft%nogrp) ) = f( 1 : dfft%tg_snd(dfft%nogrp) )
#endif
CALL stop_clock( 'IALLTOALL' )
!
!YF Contains all ( ~ NOGRP*dfft%nsw(me) ) Z-sticks
!
RETURN
END SUBROUTINE pack_group_sticks_i
#endif
!----------------------------------------------------------------------------
SUBROUTINE pack_group_sticks( f, yf, dfft )

View File

@ -9,7 +9,7 @@ program test
include 'fft_param.f90'
INTEGER, ALLOCATABLE :: req_p(:),req_u(:)
#endif
TYPE(fft_dlay_descriptor) :: dfftp, dffts, dfft3d,dfftsnow,dfftsnext
TYPE(fft_dlay_descriptor) :: dfftp, dffts, dfft3d
INTEGER :: nx = 128
INTEGER :: ny = 128
INTEGER :: nz = 256
@ -219,9 +219,6 @@ program test
gamma_only = .true.
stdout = 6
dfftsnow=dffts
dfftsnext=dffts
CALL pstickset( gamma_only, bg, gcutm, gkcut, gcutms, &
dfftp, dffts, ngw_ , ngm_ , ngs_ , mype, root, &
@ -282,21 +279,21 @@ program test
! Execute FFT calls once more and Take time
!
ncount = 0
! Copie provvisorie: CHECK
!
tempo(10) = MPI_WTIME()
!
#ifdef __DOUBLE_BUFFER
ireq = 1
ipsi = MOD( ireq + 1, 2 ) + 1
!
CALL pack_group_sticks_i( aux, psis(:, ipsi ), dffts, req_p( ireq ) )
!
nreq = 0
DO ib = 1, nbnd, 2*dffts%nogrp ! <- originale. non funziona
DO ib = 1, nbnd, 2*dffts%nogrp
nreq = nreq + 1
END DO
!
DO ib = 1, nbnd, 2*dffts%nogrp ! <- originale. non funziona
DO ib = 1, nbnd, 2*dffts%nogrp
ireq = ireq + 1
@ -316,6 +313,48 @@ program test
tempo(2) = MPI_WTIME()
CALL fw_tg_cft3_z( psis( :, ipsi ), dffts, aux )
tempo(3) = MPI_WTIME()
CALL fw_tg_cft3_scatter( psis( :, ipsi ), dffts, aux )
tempo(4) = MPI_WTIME()
CALL fw_tg_cft3_xy( psis( :, ipsi ), dffts )
tempo(5) = MPI_WTIME()
!
tmp1=1.d0
tmp2=0.d0
CALL DAXPY(10000, pi, tmp1, 1, tmp2, 1)
!
CALL bw_tg_cft3_xy( psis( :, ipsi ), dffts )
tempo(6) = MPI_WTIME()
CALL bw_tg_cft3_scatter( psis( :, ipsi ), dffts, aux )
tempo(7) = MPI_WTIME()
CALL bw_tg_cft3_z( psis( :, ipsi ), dffts, aux )
tempo(8) = MPI_WTIME()
!
CALL unpack_group_sticks( psis( :, ipsi ), aux, dffts )
!
tempo(9) = MPI_WTIME()
!
do i = 2, 10
tempo_mio(i) = tempo_mio(i) + (tempo(i) - tempo(i-1))
end do
!
ncount = ncount + 1
!
enddo
#else
ipsi = 1
!
DO ib = 1, nbnd, 2*dffts%nogrp
aux = 0.0d0
aux(1) = 1.0d0
tempo(1) = MPI_WTIME()
CALL pack_group_sticks( aux, psis(:,ipsi), dffts )
tempo(2) = MPI_WTIME()
CALL fw_tg_cft3_z( psis( :, ipsi ), dffts, aux )
tempo(3) = MPI_WTIME()
CALL fw_tg_cft3_scatter( psis( :, ipsi ), dffts, aux )
@ -345,6 +384,7 @@ program test
ncount = ncount + 1
enddo
#endif
tempo(11) = MPI_WTIME()