Replaced a couple of copies with devxlib API

This commit is contained in:
Pietro Bonfa 2020-12-20 14:13:22 +01:00
parent d66616e746
commit 7b88b45d94
1 changed files with 7 additions and 31 deletions

View File

@ -20,6 +20,7 @@ SUBROUTINE rotate_HSpsi_k_gpu( npwx, npw, nstart, nbnd, npol, psi_d, hpsi_d, ove
USE mp_bands_util, ONLY : intra_bgrp_comm, inter_bgrp_comm, root_bgrp_id, nbgrp, my_bgrp_id, &
me_bgrp, root_bgrp
USE mp, ONLY : mp_sum, mp_barrier, mp_allgather, mp_type_create_column_section, mp_type_free
USE device_memcpy_m, ONLY: dev_memcpy, dev_memset
!
IMPLICIT NONE
!
@ -159,10 +160,7 @@ SUBROUTINE rotate_HSpsi_k_gpu( npwx, npw, nstart, nbnd, npol, psi_d, hpsi_d, ove
!
call start_clock('rotHSw:diag'); !write(*,*) 'start rotHSw:diag' ; FLUSH(6)
CALL diaghg( nstart, nbnd, hh_d, ss_d, nstart, en_d, vv_d, me_bgrp, root_bgrp, intra_bgrp_comm )
!$cuf kernel do(1)
DO ii = 1, nbnd
e_d(ii) = en_d(ii)
END DO
CALL dev_memcpy(e_d, en_d, [1,nbnd])
call stop_clock('rotHSw:diag'); !write(*,*) 'stop rotHSw:diag' ; FLUSH(6)
!
! ... update the basis set
@ -178,12 +176,8 @@ SUBROUTINE rotate_HSpsi_k_gpu( npwx, npw, nstart, nbnd, npol, psi_d, hpsi_d, ove
if (n_start .le. n_end) &
CALL gpu_ZGEMM( 'N','N', kdim, my_n, nstart, (1.D0,0.D0), psi_d, kdmx, vv_d(1,n_start), nstart, &
(0.D0,0.D0), aux_d(1,n_start), kdmx )
!$cuf kernel do(2)
DO ii = 1, kdmx
DO jj = n_start, n_end
psi_d(ii,jj) = aux_d(ii, jj)
END DO
END DO
CALL dev_memcpy(psi_d, aux_d, [1, kdmx], 1, [n_start,n_end])
!call start_clock('rotHSw:ev:b3'); CALL mp_barrier( inter_bgrp_comm ); call stop_clock('rotHSw:ev:b3')
call start_clock('rotHSw:ev:s5')
CALL mp_allgather(psi_d(:,1:nbnd), column_type, recv_counts, displs, inter_bgrp_comm)
@ -192,13 +186,7 @@ SUBROUTINE rotate_HSpsi_k_gpu( npwx, npw, nstart, nbnd, npol, psi_d, hpsi_d, ove
if (n_start .le. n_end) &
CALL gpu_ZGEMM( 'N','N', kdim, my_n, nstart, (1.D0,0.D0), hpsi_d, kdmx, vv_d(1,n_start), nstart, &
(0.D0,0.D0), aux_d(1,n_start), kdmx )
!$cuf kernel do(2)
DO ii = 1, kdmx
DO jj = n_start, n_end
hpsi_d(ii,jj) = aux_d(ii,jj)
END DO
END DO
!call start_clock('rotHSw:ev:b4'); CALL mp_barrier( inter_bgrp_comm ); call stop_clock('rotHSw:ev:b4')
CALL dev_memcpy(hpsi_d, aux_d, [1, kdmx], 1, [n_start,n_end]) !call start_clock('rotHSw:ev:b4'); CALL mp_barrier( inter_bgrp_comm ); call stop_clock('rotHSw:ev:b4')
call start_clock('rotHSw:ev:s6')
CALL mp_allgather(hpsi_d(:,1:nbnd), column_type, recv_counts, displs, inter_bgrp_comm)
call stop_clock('rotHSw:ev:s6')
@ -207,26 +195,14 @@ SUBROUTINE rotate_HSpsi_k_gpu( npwx, npw, nstart, nbnd, npol, psi_d, hpsi_d, ove
if (n_start .le. n_end) &
CALL gpu_ZGEMM( 'N','N', kdim, my_n, nstart, (1.D0,0.D0), spsi_d, kdmx, vv_d(1,n_start), &
nstart, (0.D0,0.D0), aux_d(1,n_start), kdmx )
!$cuf kernel do(2)
DO ii = 1, kdmx
DO jj = n_start, n_end
spsi_d(ii,jj) = aux_d(ii,jj)
END DO
END DO
!call start_clock('rotHSw:ev:b5'); CALL mp_barrier( inter_bgrp_comm ); call stop_clock('rotHSw:ev:b5')
CALL dev_memcpy(spsi_d, aux_d, [1, kdmx], 1, [n_start,n_end]) !call start_clock('rotHSw:ev:b5'); CALL mp_barrier( inter_bgrp_comm ); call stop_clock('rotHSw:ev:b5')
call start_clock('rotHSw:ev:s7')
CALL mp_allgather(spsi_d(:,1:nbnd), column_type, recv_counts, displs, inter_bgrp_comm)
call stop_clock('rotHSw:ev:s7')
ELSE IF (present(spsi_d)) THEN
!$cuf kernel do(2)
DO ii = 1, kdmx
DO jj = 1, nbnd
spsi_d(ii,jj) = psi_d(ii,jj)
END DO
END DO
CALL dev_memcpy(spsi_d, psi_d, [1, kdmx], 1, [n_start,n_end])
END IF
DEALLOCATE( aux_d )