Remove evc_d from Davidson (use OpenACC evc instead)

This commit is contained in:
Ivan Carnimeo 2024-02-13 11:30:10 +01:00
parent 3cc2488741
commit 023e965bbf
4 changed files with 33 additions and 37 deletions

View File

@ -123,7 +123,7 @@ SUBROUTINE cegterg( h_psi_ptr, s_psi_ptr, uspp, g_psi_ptr, &
nhpsi = 0
CALL start_clock( 'cegterg' ); !write(*,*) 'start cegterg' ; FLUSH(6)
!
!$acc data deviceptr(evc, e)
!$acc data deviceptr(e)
!
IF ( nvec > nvecx / 2 ) CALL errore( 'cegterg', 'nvecx is too small', 1 )
!
@ -182,7 +182,7 @@ SUBROUTINE cegterg( h_psi_ptr, s_psi_ptr, uspp, g_psi_ptr, &
nbase = nvec
conv = .FALSE.
!
!$acc host_data use_device(psi, hpsi, spsi, hc, sc)
!$acc host_data use_device(evc, psi, hpsi, spsi, hc, sc)
CALL dev_memcpy(psi, evc, (/ 1 , npwx*npol /), 1, &
(/ 1 , nvec /), 1)
!
@ -582,11 +582,11 @@ SUBROUTINE cegterg( h_psi_ptr, s_psi_ptr, uspp, g_psi_ptr, &
!
CALL divide(inter_bgrp_comm,nbase,n_start,n_end)
my_n = n_end - n_start + 1; !write (*,*) nbase,n_start,n_end
!$acc host_data use_device(psi, vc)
!$acc host_data use_device(evc, psi, vc)
CALL ZGEMM( 'N','N', kdim, nvec, my_n, ONE, psi(1,n_start), kdmx, vc(n_start,1), nvecx, &
ZERO, evc, kdmx )
!$acc end host_data
CALL mp_sum( evc, inter_bgrp_comm )
!$acc end host_data
!
IF ( notcnv == 0 ) THEN
!
@ -611,7 +611,7 @@ SUBROUTINE cegterg( h_psi_ptr, s_psi_ptr, uspp, g_psi_ptr, &
!
! ... refresh psi, H*psi and S*psi
!
!$acc host_data use_device(psi, hpsi, spsi, vc)
!$acc host_data use_device(evc, psi, hpsi, spsi, vc)
CALL dev_memcpy(psi, evc, (/ 1, npwx*npol /), 1, &
(/ 1, nvec /), 1)
!

View File

@ -112,7 +112,7 @@ SUBROUTINE regterg( h_psi_ptr, s_psi_ptr, uspp, g_psi_ptr, &
!
CALL start_clock( 'regterg' ) !; write(6,*) 'enter regterg' ; FLUSH(6)
!
!$acc data deviceptr(evc, e)
!$acc data deviceptr(e)
!
IF ( nvec > nvecx / 2 ) CALL errore( 'regter', 'nvecx is too small', 1 )
!
@ -509,10 +509,10 @@ SUBROUTINE regterg( h_psi_ptr, s_psi_ptr, uspp, g_psi_ptr, &
!
CALL divide(inter_bgrp_comm,nbase,n_start,n_end)
my_n = n_end - n_start + 1; !write (*,*) nbase,n_start,n_end
!$acc host_data use_device(psi, vr)
!$acc host_data use_device(evc, psi, vr)
CALL DGEMM( 'N','N', npw2, nvec, my_n, 1.D0, psi(1,n_start), npwx2, vr(n_start,1), nvecx, 0.D0, evc, npwx2 )
!$acc end host_data
CALL mp_sum( evc, inter_bgrp_comm )
!$acc end host_data
!
IF ( notcnv == 0 ) THEN
!

View File

@ -56,7 +56,6 @@ SUBROUTINE c_bands( iter )
!
!
CALL start_clock( 'c_bands' ); !write (*,*) 'start c_bands' ; FLUSH(6)
CALL using_evc(0)
!
ik_ = 0
avg_iter = 0.D0
@ -67,9 +66,10 @@ SUBROUTINE c_bands( iter )
! ... directly from file, in order to avoid wasting memory)
!
DO ik = 1, ik_
IF ( nks > 1 .OR. lelfield ) &
IF ( nks > 1 .OR. lelfield ) THEN
CALL get_buffer ( evc, nwordwfc, iunwfc, ik )
IF ( nks > 1 .OR. lelfield ) CALL using_evc(1)
!$acc update device(evc)
END IF
ENDDO
!
IF ( isolve == 0 ) THEN
@ -114,9 +114,10 @@ SUBROUTINE c_bands( iter )
!
! ... read in wavefunctions from the previous iteration
!
IF ( nks > 1 .OR. lelfield ) &
IF ( nks > 1 .OR. lelfield ) THEN
CALL get_buffer ( evc, nwordwfc, iunwfc, ik )
IF ( nks > 1 .OR. lelfield ) CALL using_evc(2)
!$acc update device(evc)
END IF
!
! ... Needed for DFT+Hubbard
!
@ -128,15 +129,15 @@ SUBROUTINE c_bands( iter )
!
IF (.NOT. ( dmft .AND. .NOT. dmft_updated ) ) THEN
call diag_bands ( iter, ik, avg_iter )
!sync evc here to allow later use of converged wavefunction on host
!$acc update self(evc)
END IF
!
! ... save wave-functions to be used as input for the
! ... iterative diagonalization of the next scf iteration
! ... and for rho calculation
!
CALL using_evc(0)
IF ( nks > 1 .OR. lelfield ) &
CALL save_buffer ( evc, nwordwfc, iunwfc, ik )
IF ( nks > 1 .OR. lelfield ) CALL save_buffer ( evc, nwordwfc, iunwfc, ik )
!
! ... beware: with pools, if the number of k-points on different
! ... pools differs, make sure that all processors are still in
@ -648,7 +649,6 @@ SUBROUTINE diag_bands( iter, ik, avg_iter )
lrot = ( iter == 1 )
!
IF (.not. use_gpu) THEN
CALL using_evc(1)
IF ( use_para_diag ) THEN
! ! make sure that all processors have the same wfc
CALL pregterg( h_psi, s_psi, okvan, g_psi, &
@ -659,19 +659,20 @@ SUBROUTINE diag_bands( iter, ik, avg_iter )
npw, npwx, nbnd, nbndx, evc, ethr, &
et(1,ik), btype(1,ik), notconv, lrot, dav_iter, nhpsi ) ! BEWARE gstart has been removed from call
END IF
! CALL using_evc(1) done above
!
ELSE
CALL using_evc_d(1)
!$acc host_data use_device(et)
IF ( use_para_diag ) THEN
!$acc host_data use_device(evc)
CALL pregterg_gpu( h_psi_gpu, s_psi_acc, okvan, g_psi_gpu, &
npw, npwx, nbnd, nbndx, evc_d, ethr, &
npw, npwx, nbnd, nbndx, evc, ethr, &
et(1, ik), btype(1,ik), notconv, lrot, dav_iter, nhpsi ) ! BEWARE gstart has been removed from call
!$acc end host_data
!
ELSE
!
CALL regterg ( h_psi_gpu, s_psi_acc, okvan, g_psi_gpu, &
npw, npwx, nbnd, nbndx, evc_d, ethr, &
npw, npwx, nbnd, nbndx, evc, ethr, &
et(1, ik), btype(1,ik), notconv, lrot, dav_iter, nhpsi ) ! BEWARE gstart has been removed from call
END IF
!$acc end host_data
@ -1050,7 +1051,6 @@ SUBROUTINE diag_bands( iter, ik, avg_iter )
lrot = ( iter == 1 )
!
IF (.not. use_gpu ) THEN
CALL using_evc(1)
IF ( use_para_diag ) then
!
CALL pcegterg( h_psi, s_psi, okvan, g_psi, &
@ -1064,19 +1064,19 @@ SUBROUTINE diag_bands( iter, ik, avg_iter )
et(1,ik), btype(1,ik), notconv, lrot, dav_iter, nhpsi )
END IF
ELSE
CALL using_evc_d(1)
!$acc host_data use_device(et)
IF ( use_para_diag ) then
!
!$acc host_data use_device(evc)
CALL pcegterg_gpu( h_psi_gpu, s_psi_acc, okvan, g_psi_gpu, &
npw, npwx, nbnd, nbndx, npol, evc_d, ethr, &
npw, npwx, nbnd, nbndx, npol, evc, ethr, &
et(1, ik), btype(1,ik), notconv, lrot, dav_iter, nhpsi )
!$acc end host_data
!
ELSE
!
CALL cegterg ( h_psi_gpu, s_psi_acc, okvan, g_psi_gpu, &
npw, npwx, nbnd, nbndx, npol, evc_d, ethr, &
npw, npwx, nbnd, nbndx, npol, evc, ethr, &
et(1, ik), btype(1,ik), notconv, lrot, dav_iter, nhpsi )
END IF
!$acc end host_data

View File

@ -34,7 +34,6 @@ SUBROUTINE wfcinit()
USE qexsd_module, ONLY : qexsd_readschema
USE qes_types_module, ONLY : output_type
USE qes_libs_module, ONLY : qes_reset
USE wavefunctions_gpum, ONLY : using_evc
USE uspp_init, ONLY : init_us_2
USE control_flags, ONLY : use_gpu
!
@ -46,7 +45,6 @@ SUBROUTINE wfcinit()
TYPE ( output_type ) :: output_obj
!
CALL start_clock( 'wfcinit' )
CALL using_evc(0) ! this may be removed
!
! ... Orthogonalized atomic functions needed for DFT+U and other cases
!
@ -119,8 +117,8 @@ SUBROUTINE wfcinit()
IF ( nks == 1 ) THEN
INQUIRE (unit = iunwfc, opened = opnd_file)
IF ( .NOT.opnd_file ) CALL diropn( iunwfc, 'wfc', 2*nwordwfc, exst )
CALL using_evc(2)
CALL davcio ( evc, 2*nwordwfc, iunwfc, nks, -1 )
!$acc update device(evc)
IF ( .NOT.opnd_file ) CLOSE ( UNIT=iunwfc, STATUS='keep' )
END IF
END IF
@ -197,9 +195,9 @@ SUBROUTINE wfcinit()
!
! ... write starting wavefunctions to file
!
IF ( nks > 1 .OR. (io_level > 1) .OR. lelfield ) CALL using_evc(0)
IF ( nks > 1 .OR. (io_level > 1) .OR. lelfield ) &
CALL save_buffer ( evc, nwordwfc, iunwfc, ik )
IF ( nks > 1 .OR. (io_level > 1) .OR. lelfield ) THEN
CALL save_buffer ( evc, nwordwfc, iunwfc, ik )
END IF
!
END DO
!
@ -241,7 +239,6 @@ SUBROUTINE init_wfc ( ik )
USE mp, ONLY : mp_bcast
USE xc_lib, ONLY : xclib_dft_is, stop_exx
!
USE wavefunctions_gpum, ONLY : using_evc, using_evc_d, evc_d
USE control_flags, ONLY : lscf, use_gpu
!
IMPLICIT NONE
@ -419,13 +416,12 @@ SUBROUTINE init_wfc ( ik )
IF ( xclib_dft_is('hybrid') .and. lscf ) CALL stop_exx()
CALL start_clock( 'wfcinit:wfcrot' ); !write(*,*) 'start wfcinit:wfcrot' ; FLUSH(6)
IF(use_gpu) THEN
CALL using_evc_d(2) ! rotate_wfc_gpu (..., evc_d, etatom_d) -> evc : out (not specified)
!$acc host_data use_device(wfcatom,etatom)
CALL rotate_wfc_gpu ( npwx, ngk_ik, n_starting_wfc, gstart, nbnd, wfcatom, npol, okvan, evc_d, etatom )
!$acc host_data use_device(wfcatom,etatom,evc)
CALL rotate_wfc_gpu ( npwx, ngk_ik, n_starting_wfc, gstart, nbnd, wfcatom, npol, okvan, evc, etatom )
!$acc end host_data
!$acc update self(evc)
ELSE
CALL rotate_wfc ( npwx, ngk(ik), n_starting_wfc, gstart, nbnd, wfcatom, npol, okvan, evc, etatom )
CALL using_evc(1) ! rotate_wfc (..., evc, etatom) -> evc : out (not specified)
END IF
CALL stop_clock( 'wfcinit:wfcrot' ); !write(*,*) 'stop wfcinit:wfcrot' ; FLUSH(6)
!