mirror of https://gitlab.com/QEF/q-e.git
Almost working, but not quite: fft_scatter_many_yz still fails for large MPI pools, and device selection is still missing.
This commit is contained in:
parent
ddfe4181ad
commit
34e9ae5b78
|
@ -33,10 +33,14 @@ program test_fft_scatter_mod_gpu
|
|||
!
|
||||
DO i = 1, mp%n
|
||||
IF (MOD(mp%n,i) == 0 ) THEN
|
||||
! gamma case
|
||||
CALL test_fft_scatter_xy_gpu_1(mp, test, .true., i)
|
||||
! k case
|
||||
CALL test_fft_scatter_xy_gpu_1(mp, test, .false., i)
|
||||
!
|
||||
! gamma case
|
||||
CALL test_fft_scatter_yz_gpu_1(mp, test, .true., i)
|
||||
! k case
|
||||
CALL test_fft_scatter_yz_gpu_1(mp, test, .false., i)
|
||||
END IF
|
||||
END DO
|
||||
|
@ -136,15 +140,19 @@ program test_fft_scatter_mod_gpu
|
|||
COMPLEX(DP), ALLOCATABLE, DEVICE :: scatter_in_d(:), scatter_out_d(:)
|
||||
integer(kind = cuda_stream_kind) :: stream = 0
|
||||
integer :: fft_sign = 2
|
||||
integer :: vsiz
|
||||
integer :: vsiz, nr1p_, compare_len, me2
|
||||
!
|
||||
parallel = mp%n .gt. 1
|
||||
|
||||
CALL fft_desc_init(dfft, smap, "wave", gamma_only, parallel, mp%comm, nyfft=ny)
|
||||
me2 = dfft%mype2 + 1
|
||||
vsiz = dfft%nnr
|
||||
compare_len = dfft%nr1x * dfft%my_nr2p * dfft%my_nr3p
|
||||
if (ny > 1) then
|
||||
! When using task groups, wave FFTs are not distributed along Y
|
||||
fft_sign = 3
|
||||
vsiz = dfft%nnr_tg
|
||||
compare_len = dfft%nr1x * dfft%nr2x * dfft%my_nr3p
|
||||
end if
|
||||
!
|
||||
! Allocate variables
|
||||
|
@ -156,19 +164,24 @@ program test_fft_scatter_mod_gpu
|
|||
!
|
||||
CALL fft_scatter_xy( dfft, scatter_in, scatter_out, vsiz, fft_sign )
|
||||
CALL fft_scatter_xy_gpu( dfft, scatter_in_d, scatter_out_d, vsiz, fft_sign, stream )
|
||||
aux(1:vsiz) = scatter_out_d(1:vsiz)
|
||||
aux(1:compare_len) = scatter_out_d(1:compare_len)
|
||||
!
|
||||
! Check
|
||||
CALL test%assert_close( scatter_out, aux )
|
||||
CALL test%assert_close( scatter_out(1:compare_len), aux(1:compare_len) )
|
||||
!
|
||||
! Test 2
|
||||
CALL fill_random(scatter_in, scatter_in_d, vsiz)
|
||||
!
|
||||
CALL fft_scatter_xy( dfft, scatter_out, scatter_in, vsiz, -1*fft_sign )
|
||||
CALL fft_scatter_xy_gpu( dfft, scatter_out_d, scatter_in_d, vsiz, -1*fft_sign, stream )
|
||||
aux(1:vsiz) = scatter_out_d(1:vsiz)
|
||||
!
|
||||
compare_len = dfft%nr2x * dfft%nr1w(me2) * dfft%my_nr3p
|
||||
IF (ny > 1) compare_len = dfft%nr2x * dfft%nr1w_tg * dfft%my_nr3p
|
||||
!
|
||||
aux(1:compare_len) = scatter_out_d(1:compare_len)
|
||||
! Check
|
||||
CALL test%assert_close( scatter_out, aux )
|
||||
!!
|
||||
CALL test%assert_close( scatter_out(1:compare_len), aux(1:compare_len) )
|
||||
!
|
||||
CALL fft_desc_finalize(dfft, smap)
|
||||
DEALLOCATE(scatter_in, scatter_out, aux, scatter_in_d, scatter_out_d)
|
||||
!
|
||||
|
@ -199,14 +212,16 @@ program test_fft_scatter_mod_gpu
|
|||
COMPLEX(DP), ALLOCATABLE, DEVICE :: scatter_in_d(:), scatter_out_d(:)
|
||||
integer(kind = cuda_stream_kind) :: stream = 0
|
||||
integer :: fft_sign = 2
|
||||
integer :: vsiz
|
||||
integer :: vsiz, compare_len, my_nr1p_
|
||||
!
|
||||
parallel = mp%n .gt. 1
|
||||
CALL fft_desc_init(dfft, smap, "wave", gamma_only, parallel, mp%comm, nyfft=ny)
|
||||
vsiz = dfft%nnr
|
||||
vsiz = dfft%nnr
|
||||
my_nr1p_ = count(dfft%ir1w > 0)
|
||||
if (ny > 1) then
|
||||
fft_sign = 3
|
||||
vsiz = dfft%nnr_tg
|
||||
my_nr1p_ = count(dfft%ir1w_tg > 0)
|
||||
end if
|
||||
!
|
||||
! Allocate variables
|
||||
|
@ -218,18 +233,23 @@ program test_fft_scatter_mod_gpu
|
|||
!
|
||||
CALL fft_scatter_yz( dfft, scatter_in, scatter_out, vsiz, fft_sign )
|
||||
CALL fft_scatter_yz_gpu( dfft, scatter_in_d, scatter_out_d, vsiz, fft_sign )
|
||||
aux(1:vsiz) = scatter_out_d(1:vsiz)
|
||||
! Set the number of elements that should be strictly equivalent in the
|
||||
! two implementations.
|
||||
compare_len = dfft%my_nr3p*my_nr1p_*dfft%nr2x
|
||||
aux(1:compare_len) = scatter_out_d(1:compare_len)
|
||||
! Check
|
||||
CALL test%assert_close( scatter_out, aux )
|
||||
CALL test%assert_close( scatter_out(1:compare_len), aux(1:compare_len) )
|
||||
!
|
||||
! Test 2
|
||||
CALL fill_random(scatter_in, scatter_in_d, vsiz)
|
||||
!
|
||||
CALL fft_scatter_yz( dfft, scatter_out, scatter_in, vsiz, -1*fft_sign )
|
||||
CALL fft_scatter_yz_gpu( dfft, scatter_out_d, scatter_in_d, vsiz, -1*fft_sign )
|
||||
aux(1:vsiz) = scatter_out_d(1:vsiz)
|
||||
!
|
||||
compare_len = dfft%nsw(mp%me+1)*dfft%nr3x
|
||||
aux(1:compare_len) = scatter_out_d(1:compare_len)
|
||||
! Check
|
||||
CALL test%assert_close( scatter_out, aux )
|
||||
CALL test%assert_close( scatter_out(1:compare_len), aux(1:compare_len) )
|
||||
!
|
||||
CALL fft_desc_finalize(dfft, smap)
|
||||
DEALLOCATE(scatter_in, scatter_out, aux, scatter_in_d, scatter_out_d)
|
||||
|
|
Loading…
Reference in New Issue