quantum-espresso/UtilXlib/tests/test_mp_alltoall_gpu.tmpl

88 lines
2.5 KiB
Cheetah

! Implemented: it, ct
#if defined(__CUDA)
PROGRAM test_mp_alltoall_{vname}_gpu
  !
  ! Unit test for mp_alltoall called with device (GPU) resident buffers.
  !
  ! Two checks are performed:
  !   1. Every process fills its send buffer with the constant mpime + 1.
  !      After the exchange, slice i of the receive buffer is expected to
  !      sum to i * datasize**2.
  !   2. The same random data is exchanged once through device buffers and
  !      once through host buffers; the sums of the two receive buffers
  !      must agree.
  !
  USE cudafor
  USE parallel_include
  USE util_param, ONLY : DP
  USE mp, ONLY : mp_alltoall
  USE mp_world, ONLY : mp_world_start, mp_world_end, mpime, &
                       root, nproc, world_comm
  USE tester
  IMPLICIT NONE
  !
  TYPE(tester_t) :: test
  ! Communicator handed to mp_world_start; stays 0 in builds without MPI.
  INTEGER :: world_group
  ! Extent of the first two dimensions of every buffer.
  INTEGER, PARAMETER :: datasize = {datasize}
  !
  ! Loop index over the third buffer dimension (one slice per process).
  INTEGER :: i
  ! Random numbers used to build the payload of the second check.
  REAL(DP), ALLOCATABLE :: rnd{all}
  !
  ! Device-side send/receive buffers (the interface under test).
  {type}, ALLOCATABLE, DEVICE :: snd_{vname}_d{all}
  {type}, ALLOCATABLE, DEVICE :: rcv_{vname}_d{all}
  ! Host-side send/receive buffers (staging and reference exchange).
  {type}, ALLOCATABLE :: snd_{vname}_h{all}
  {type}, ALLOCATABLE :: rcv_{vname}_h{all}
  ! Host copy of the device receive buffer, used for the comparison.
  {type}, ALLOCATABLE :: aux_h{all}
  !
  CALL test%init()
  !
  ! Assigned here rather than in the declaration so that no implicit SAVE
  ! initialisation is involved.
  world_group = 0
#if defined(__MPI)
  world_group = MPI_COMM_WORLD
#endif
  CALL mp_world_start(world_group)
  !
  ALLOCATE(rnd(datasize, datasize, nproc))
  ALLOCATE(snd_{vname}_d(datasize, datasize, nproc))
  ALLOCATE(rcv_{vname}_d(datasize, datasize, nproc))
  ALLOCATE(snd_{vname}_h(datasize, datasize, nproc))
  ALLOCATE(rcv_{vname}_h(datasize, datasize, nproc))
  ALLOCATE(aux_h(datasize, datasize, nproc))
  !
  ! ---- Check 1: constant payload --------------------------------------
  !
  ! Fill on the host, then copy to the device send buffer.
  snd_{vname}_h = mpime + 1
  snd_{vname}_d = snd_{vname}_h
  !
  ! Argument order: mp_alltoall (sndbuf, rcvbuf, gid)
  CALL mp_alltoall(snd_{vname}_d, rcv_{vname}_d, world_comm)
  !
  ! Bring the result back to the host before inspecting it.
  rcv_{vname}_h = rcv_{vname}_d
  !
  DO i = 1, nproc
     CALL test%assert_equal(SUM(rcv_{vname}_h(:,:,i)) , {typeconv} (i*(datasize**2)))
  END DO
  !
  ! ---- Check 2: device exchange against host exchange -------------------
  !
  ! Record the seed so that the random payload can be traced back.
  CALL save_random_seed("test_mp_alltoall_{vname}_gpu", mpime)
  !
  CALL RANDOM_NUMBER(rnd)
  snd_{vname}_h = {typeconv} ( 10.0_DP * rnd )
  snd_{vname}_d = snd_{vname}_h
  aux_h = 0
  !
  ! Same payload, once through device buffers and once through host buffers.
  CALL mp_alltoall(snd_{vname}_d, rcv_{vname}_d, world_comm)
  CALL mp_alltoall(snd_{vname}_h, rcv_{vname}_h, world_comm)
  aux_h = rcv_{vname}_d
  !
  CALL test%assert_equal( SUM(rcv_{vname}_h) , SUM(aux_h) )
  !
  ! Release every buffer allocated above, aux_h included.
  DEALLOCATE(rnd, snd_{vname}_d, snd_{vname}_h)
  DEALLOCATE(rcv_{vname}_d, rcv_{vname}_h)
  DEALLOCATE(aux_h)
  !
  CALL collect_results(test)
  !
  CALL mp_world_end()
  !
  ! Only one process prints the report.
  IF (mpime .eq. 0) CALL test%print()
  !
END PROGRAM test_mp_alltoall_{vname}_gpu
#else
PROGRAM test_mp_alltoall_{vname}_gpu
  !
  ! Fallback used when the code is built without __CUDA: there is no device
  ! interface to exercise, so the program only calls no_test.
  !
  IMPLICIT NONE
  !
  CALL no_test()
  !
END PROGRAM test_mp_alltoall_{vname}_gpu
#endif