- further OpenMP parallelization

git-svn-id: http://qeforge.qe-forge.org/svn/q-e/trunk/espresso@5565 c92efa57-630b-4861-b058-cf58834340f0
This commit is contained in:
ccavazzoni 2009-05-23 16:23:17 +00:00
parent a804f0ff18
commit aa20220127
2 changed files with 98 additions and 62 deletions

View File

@ -99,8 +99,9 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
#ifdef __PARA
INTEGER :: dest, from, k, offset, proc, ierr, me, nprocp, gproc, gcomm, i, kdest, kfrom
INTEGER :: sendcount (nproc_pool), sdispls, recvcount (nproc_pool), rdispls
INTEGER :: dest, from, k, ip, proc, ierr, me, me_pgrp, nprocp, gproc, gcomm, i, kdest, kfrom
INTEGER :: sendcount(nproc_pool), sdispls(nproc_pool), recvcount(nproc_pool), rdispls(nproc_pool)
INTEGER :: offset(nproc_pool)
INTEGER :: sh(nproc_pool), rh(nproc_pool)
!
LOGICAL :: use_tg_ , lrcv, lsnd
@ -123,13 +124,11 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
IF( use_tg_ ) THEN
! This is the number of procs. in the plane-wave group
nprocp = nproc_pool / nogrp
ELSE
nprocp = nproc_pool
END IF
!
IF( use_tg_ ) THEN
CALL mpi_comm_rank( pgrp_comm, me_pgrp, ierr )
gcomm = pgrp_comm
ELSE
nprocp = nproc_pool
me_pgrp = me_pool
gcomm = intra_pool_comm
END IF
!
@ -144,6 +143,35 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
! sdispls+1 is the beginning of data that must be sent to proc
! rdispls+1 is the beginning of data that must be received from proc
!
IF( use_tg_ ) THEN
do proc = 1, nprocp
gproc = nplist( proc ) + 1
sendcount (proc) = npp_ ( gproc ) * ncp_ (me)
recvcount (proc) = npp_ (me) * ncp_ ( gproc )
end do
offset(1) = 0
do proc = 2, nprocp
gproc = nplist( proc - 1 ) + 1
offset(proc) = offset(proc - 1) + npp_ ( gproc )
end do
ELSE
do proc = 1, nprocp
sendcount (proc) = npp_ (proc) * ncp_ (me)
recvcount (proc) = npp_ (me) * ncp_ (proc)
end do
offset(1) = 0
do proc = 2, nprocp
offset(proc) = offset(proc - 1) + npp_ (proc - 1)
end do
END IF
!
sdispls (1) = 0
rdispls (1) = 0
do proc = 2, nprocp
sdispls (proc) = sdispls (proc - 1) + sendcount (proc - 1)
rdispls (proc) = rdispls (proc - 1) + recvcount (proc - 1)
enddo
!
ierr = 0
!
if ( sign > 0 ) then
@ -152,21 +180,21 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
!
! step one: store contiguously the slices
!
offset = 1
sdispls = 0
!
do proc = 1, nprocp
do ip = 1, nprocp
! the following two lines make the loop iterations different on each
! proc in order to avoid that all procs send a msg at the same proc
! at the same time.
!
proc = me_pgrp + 1 + ip
IF( proc > nprocp ) proc = proc - nprocp
gproc = proc
IF( use_tg_ ) gproc = nplist( proc ) + 1
!
sendcount (proc) = npp_ ( gproc ) * ncp_ (me)
from = offset
dest = 1 + sdispls
from = 1 + offset( proc )
dest = 1 + sdispls( proc )
!
! optimize for large parallel execution, where npp_ ( gproc ) ~ 1
!
SELECT CASE ( npp_ ( gproc ) )
@ -193,6 +221,7 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
f_aux ( dest + (k - 1) * 4 - 1 + 4 ) = f_in ( from + (k - 1) * nrx3 - 1 + 4 )
enddo
CASE DEFAULT
!$omp parallel do default(shared), private(i, kdest, kfrom)
do k = 1, ncp_ (me)
kdest = dest + (k - 1) * npp_ ( gproc ) - 1
kfrom = from + (k - 1) * nrx3 - 1
@ -204,35 +233,28 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
!
! post the non-blocking send, f_aux can't be overwritten until operation has completed
!
call mpi_isend( f_aux( sdispls + 1 ), sendcount( proc ), MPI_DOUBLE_COMPLEX, &
call mpi_isend( f_aux( sdispls( proc ) + 1 ), sendcount( proc ), MPI_DOUBLE_COMPLEX, &
proc-1, me, gcomm, sh( proc ), ierr )
!
if( ABS(ierr) /= 0 ) call errore ('fft_scatter', ' forward send info<>0', ABS(ierr) )
!
offset = offset + npp_ ( gproc )
sdispls = sdispls + sendcount (proc)
!
end do
!
! step two: communication
!
rdispls = 0
!
do proc = 1, nprocp
do ip = 1, nprocp
!
gproc = proc
IF( use_tg_ ) gproc = nplist( proc ) + 1
!
recvcount (proc) = npp_ (me) * ncp_ ( gproc )
proc = me_pgrp + 1 - ip
IF( proc < 1 ) proc = proc + nprocp
!
! now post the receive
!
CALL mpi_irecv( f_in( rdispls + 1 ), recvcount( proc ), MPI_DOUBLE_COMPLEX, &
CALL mpi_irecv( f_in( rdispls( proc ) + 1 ), recvcount( proc ), MPI_DOUBLE_COMPLEX, &
proc-1, MPI_ANY_TAG, gcomm, rh( proc ), ierr )
!
if( ABS(ierr) /= 0 ) call errore ('fft_scatter', ' forward receive info<>0', ABS(ierr) )
!
rdispls = rdispls + recvcount (proc)
tstr( proc ) = .false.
tsts( proc ) = .false.
!
@ -240,7 +262,7 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
!
! maybe useless; ensures that no garbage is present in the output
!
f_in( rdispls + 1 : SIZE( f_in ) ) = 0.0_DP
f_in( rdispls( nprocp ) + recvcount( nprocp ) + 1 : SIZE( f_in ) ) = 0.0_DP
!
lrcv = .false.
lsnd = .false.
@ -264,41 +286,35 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
lsnd = lsnd .and. tsts( proc )
!
end do
!
!
end do
!
else
!
! "backward" scatter from planes to columns
!
sdispls = 0
rdispls = 0
!
do proc = 1, nprocp
gproc = proc
IF( use_tg_ ) gproc = nplist( proc ) + 1
sendcount (proc) = npp_ ( gproc ) * ncp_ (me)
recvcount (proc) = npp_ (me) * ncp_ ( gproc )
do ip = 1, nprocp
! post the non blocking send
call mpi_isend( f_in( rdispls + 1 ), recvcount( proc ), MPI_DOUBLE_COMPLEX, &
proc = me_pgrp + 1 + ip
IF( proc > nprocp ) proc = proc - nprocp
call mpi_isend( f_in( rdispls( proc ) + 1 ), recvcount( proc ), MPI_DOUBLE_COMPLEX, &
proc-1, me, gcomm, sh( proc ), ierr )
if( ABS(ierr) /= 0 ) call errore ('fft_scatter', ' backward send info<>0', ABS(ierr) )
! post the non blocking receive
CALL mpi_irecv( f_aux( sdispls + 1 ), sendcount( proc ), MPI_DOUBLE_COMPLEX, &
proc = me_pgrp + 1 - ip
IF( proc < 1 ) proc = proc + nprocp
CALL mpi_irecv( f_aux( sdispls( proc ) + 1 ), sendcount( proc ), MPI_DOUBLE_COMPLEX, &
proc-1, MPI_ANY_TAG, gcomm, rh(proc), ierr )
if( ABS(ierr) /= 0 ) call errore ('fft_scatter', ' backward receive info<>0', ABS(ierr) )
sdispls = sdispls + sendcount (proc)
rdispls = rdispls + recvcount (proc)
tstr( 1 : proc ) = .false.
tsts( 1 : proc ) = .false.
tstr( ip ) = .false.
tsts( ip ) = .false.
end do
!
@ -323,13 +339,12 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
end do
!
lrcv = .false.
!
do while ( .not. lrcv )
!
lrcv = .true.
!
offset = 1
sdispls = 0
!
do proc = 1, nprocp
gproc = proc
@ -341,8 +356,8 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
IF( tstr( proc ) ) THEN
from = 1 + sdispls
dest = offset
from = 1 + sdispls( proc )
dest = 1 + offset( proc )
!
! optimize for large parallel execution, where npp_ ( gproc ) ~ 1
!
@ -370,6 +385,7 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
f_in ( dest + (k - 1) * nrx3 - 1 + 4 ) = f_aux( from + (k - 1) * 4 - 1 + 4 )
enddo
CASE DEFAULT
!$omp parallel do default(shared), private(i, kdest, kfrom)
do k = 1, ncp_ ( me )
kdest = dest + (k - 1) * nrx3 - 1
kfrom = from + (k - 1) * npp_ ( gproc ) - 1
@ -385,9 +401,6 @@ subroutine fft_scatter ( f_in, nrx3, nxx_, f_aux, ncp_, npp_, sign, use_tg )
lrcv = lrcv .and. tstr( proc )
offset = offset + npp_ ( gproc )
sdispls = sdispls + sendcount (proc)
end do
end do

View File

@ -319,7 +319,7 @@ CONTAINS
use fft_base, only: fft_scatter
!
INTEGER, INTENT(IN) :: iopt
INTEGER :: nppx, ip, nnp, npp, ii, i, mc, j
INTEGER :: nppx, ip, nnp, npp, ii, i, mc, j, ioff
!
IF( iopt == 2 ) THEN
!
@ -342,17 +342,26 @@ CONTAINS
!
END IF
!
f(:) = (0.d0, 0.d0)
!
!$omp parallel default(shared), private( ii, mc, j, i, ioff, ip )
!$omp do
do i = 1, SIZE( f )
f(i) = (0.d0, 0.d0)
end do
!
ii = 0
!
do ip = 1, nproc_pool
!
ioff = dfft%iss( ip )
!
do i = 1, dfft%nsw( ip )
!
mc = dfft%ismap( i + dfft%iss( ip ) )
mc = dfft%ismap( i + ioff )
!
ii = ii + 1
!
!$omp do
do j = 1, npp
f( mc + ( j - 1 ) * nnp ) = aux( j + ( ii - 1 ) * nppx )
end do
@ -360,6 +369,7 @@ CONTAINS
end do
!
end do
!$omp end parallel
!
ELSE IF( iopt == 1 ) THEN
!
@ -371,14 +381,20 @@ CONTAINS
!
call fft_scatter( aux, nx3, dfft%nnr, f, dfft%nsp, dfft%npp, iopt )
!
f(:) = (0.d0, 0.d0)
!$omp parallel default(shared)
!$omp do
do i = 1, SIZE(f)
f(i) = (0.d0, 0.d0)
end do
!
!$omp do private(mc,j)
do i = 1, dfft%nst
mc = dfft%ismap( i )
do j = 1, dfft%npp( me_p )
f( mc + ( j - 1 ) * dfft%nnp ) = aux( j + ( i - 1 ) * nppx )
end do
end do
!$omp end parallel
!
END IF
!
@ -412,17 +428,20 @@ CONTAINS
!
END IF
ii = 0
!$omp parallel default(shared), private( mc, j, i, ii, ip )
ii = 0
do ip = 1, nproc_pool
do i = 1, dfft%nsw( ip )
mc = dfft%ismap( i + dfft%iss( ip ) )
ii = ii + 1
!$omp do
do j = 1, npp
aux( j + ( ii - 1 ) * nppx ) = f( mc + ( j - 1 ) * nnp )
end do
end do
end do
!$omp end parallel
!
IF( use_tg ) THEN
!
@ -441,12 +460,16 @@ CONTAINS
else
nppx = dfft%npp( me_p )
end if
!$omp parallel default(shared), private( mc, j, i )
!$omp do
do i = 1, dfft%nst
mc = dfft%ismap( i )
do j = 1, dfft%npp( me_p )
aux( j + ( i - 1 ) * nppx ) = f( mc + ( j - 1 ) * dfft%nnp )
end do
end do
!$omp end parallel
!
call fft_scatter( aux, nx3, dfft%nnr, f, dfft%nsp, dfft%npp, iopt )
!
END IF