Remove all unnecessary mem ops in cegterg.

This commit is contained in:
Ye Luo 2018-05-27 21:54:46 -05:00
parent 0f340dd372
commit 2c6c859896
3 changed files with 251 additions and 45 deletions

View File

@ -68,7 +68,8 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
USE david_param, ONLY : DP
USE mp_bands_util, ONLY : intra_bgrp_comm, inter_bgrp_comm, root_bgrp_id,&
nbgrp, my_bgrp_id
USE mp, ONLY : mp_sum, mp_bcast
USE mp, ONLY : mp_sum, mp_allgather, mp_bcast, mp_size,&
mp_type_create_column_section, mp_type_free
!
IMPLICIT NONE
!
@ -109,6 +110,8 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
! adapted npw and npwx
! do-loop counters
INTEGER :: n_start, n_end, my_n
INTEGER :: column_section_type
! defines a column section for communication
INTEGER :: ierr
COMPLEX(DP), ALLOCATABLE :: hc(:,:), sc(:,:), vc(:,:)
! Hamiltonian on the reduced basis
@ -124,6 +127,8 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
! true if the root is converged
REAL(DP) :: empty_ethr
! threshold for empty bands
INTEGER, ALLOCATABLE :: recv_counts(:), displs(:)
! receive counts and memory offsets
!
REAL(DP), EXTERNAL :: ddot
!
@ -185,16 +190,14 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
ALLOCATE( conv( nvec ), STAT=ierr )
IF( ierr /= 0 ) &
CALL errore( ' cegterg ',' cannot allocate conv ', ABS(ierr) )
ALLOCATE( recv_counts(mp_size(inter_bgrp_comm)), displs(mp_size(inter_bgrp_comm)) )
!
notcnv = nvec
nbase = nvec
conv = .FALSE.
!
!$omp parallel
IF ( uspp ) CALL threaded_fill_value_nowait(spsi, nvecx*npol*npwx, ZERO)
CALL threaded_fill_value_nowait(hpsi, nvecx*npol*npwx, ZERO)
CALL threaded_fill_array_nowait(psi, nvec*npol*npwx, evc)
CALL threaded_fill_value_nowait(psi(1,1,nvec+1), (nvecx-nvec)*npol*npwx, ZERO)
!$omp end parallel
!
! ... hpsi contains h times the basis vectors
@ -209,17 +212,16 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
! ... space vc contains the eigenvectors of hc
!
CALL start_clock( 'cegterg:init' )
hc(:,:) = ZERO
sc(:,:) = ZERO
vc(:,:) = ZERO
!
CALL divide(inter_bgrp_comm,nbase,n_start,n_end)
CALL divide_all(inter_bgrp_comm,nbase,n_start,n_end,recv_counts,displs)
CALL mp_type_create_column_section(sc(1,1), 0, nbase, nvecx, column_section_type)
my_n = n_end - n_start + 1; !write (*,*) nbase,n_start,n_end
!
if (n_start .le. n_end) &
CALL ZGEMM( 'C','N', nbase, my_n, kdim, ONE, psi, kdmx, hpsi(1,1,n_start), kdmx, ZERO, hc(1,n_start), nvecx )
CALL mp_sum( hc( :, 1:nbase ), inter_bgrp_comm )
!
CALL mp_sum( hc( :, 1:nbase ), intra_bgrp_comm )
if (n_start .le. n_end) CALL mp_sum( hc( 1:nbase, n_start:n_end ), intra_bgrp_comm )
CALL mp_allgather( hc(1,1), column_section_type, recv_counts, displs, inter_bgrp_comm )
!
IF ( uspp ) THEN
!
@ -234,12 +236,33 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
ZERO, sc(1,n_start), nvecx )
!
END IF
CALL mp_sum( sc( :, 1:nbase ), inter_bgrp_comm )
!
CALL mp_sum( sc( :, 1:nbase ), intra_bgrp_comm )
if (n_start .le. n_end) CALL mp_sum( sc( 1:nbase, n_start:n_end ), intra_bgrp_comm )
CALL mp_allgather( sc(1,1), column_section_type, recv_counts, displs, inter_bgrp_comm )
!
CALL mp_type_free( column_section_type )
!
DO n = 1, nbase
!
! ... the diagonal of hc and sc must be strictly real
!
hc(n,n) = CMPLX( REAL( hc(n,n) ), 0.D0 ,kind=DP)
sc(n,n) = CMPLX( REAL( sc(n,n) ), 0.D0 ,kind=DP)
!
DO m = n + 1, nbase
!
hc(n,m) = CONJG( hc(m,n) )
sc(n,m) = CONJG( sc(m,n) )
!
END DO
!
END DO
!
CALL stop_clock( 'cegterg:init' )
!
IF ( lrot ) THEN
!
vc(1:nbase,1:nbase) = ZERO
!
DO n = 1, nbase
!
@ -305,7 +328,6 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
!
CALL divide(inter_bgrp_comm,nbase,n_start,n_end)
my_n = n_end - n_start + 1; !write (*,*) nbase,n_start,n_end
psi(:,:,nb1:nbase+notcnv)=ZERO
IF ( uspp ) THEN
!
if (n_start .le. n_end) &
@ -379,32 +401,34 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
!
CALL start_clock( 'cegterg:overlap' )
!
hc( :, nb1:nb1+notcnv-1 )=ZERO
CALL divide(inter_bgrp_comm,nbase+notcnv,n_start,n_end)
CALL divide_all(inter_bgrp_comm,nbase+notcnv,n_start,n_end,recv_counts,displs)
CALL mp_type_create_column_section(sc(1,1), nbase, notcnv, nvecx, column_section_type)
my_n = n_end - n_start + 1; !write (*,*) nbase+notcnv,n_start,n_end
CALL ZGEMM( 'C','N', my_n, notcnv, kdim, ONE, psi(1,1,n_start), kdmx, hpsi(1,1,nb1), kdmx, &
ZERO, hc(n_start,nb1), nvecx )
CALL mp_sum( hc( :, nb1:nb1+notcnv-1 ), inter_bgrp_comm )
!
CALL mp_sum( hc( :, nb1:nb1+notcnv-1 ), intra_bgrp_comm )
CALL ZGEMM( 'C','N', notcnv, my_n, kdim, ONE, hpsi(1,1,nb1), kdmx, psi(1,1,n_start), kdmx, &
ZERO, hc(nb1,n_start), nvecx )
!
if (n_start .le. n_end) CALL mp_sum( hc( nb1:nbase+notcnv, n_start:n_end ), intra_bgrp_comm )
CALL mp_allgather( hc(1,1), column_section_type, recv_counts, displs, inter_bgrp_comm )
!
sc( :, nb1:nb1+notcnv-1 )=ZERO
CALL divide(inter_bgrp_comm,nbase+notcnv,n_start,n_end)
my_n = n_end - n_start + 1; !write (*,*) nbase+notcnv,n_start,n_end
IF ( uspp ) THEN
!
CALL ZGEMM( 'C','N', my_n, notcnv, kdim, ONE, psi(1,1,n_start), kdmx, spsi(1,1,nb1), kdmx, &
ZERO, sc(n_start,nb1), nvecx )
CALL ZGEMM( 'C','N', notcnv, my_n, kdim, ONE, spsi(1,1,nb1), kdmx, psi(1,1,n_start), kdmx, &
ZERO, sc(nb1,n_start), nvecx )
!
ELSE
!
CALL ZGEMM( 'C','N', my_n, notcnv, kdim, ONE, psi(1,1,n_start), kdmx, psi(1,1,nb1), kdmx, &
ZERO, sc(n_start,nb1), nvecx )
CALL ZGEMM( 'C','N', notcnv, my_n, kdim, ONE, psi(1,1,nb1), kdmx, psi(1,1,n_start), kdmx, &
ZERO, sc(nb1,n_start), nvecx )
!
END IF
CALL mp_sum( sc( :, nb1:nb1+notcnv-1 ), inter_bgrp_comm )
!
CALL mp_sum( sc( :, nb1:nb1+notcnv-1 ), intra_bgrp_comm )
if (n_start .le. n_end) CALL mp_sum( sc( nb1:nbase+notcnv, n_start:n_end ), intra_bgrp_comm )
CALL mp_allgather( sc(1,1), column_section_type, recv_counts, displs, inter_bgrp_comm )
!
CALL mp_type_free( column_section_type )
!
CALL stop_clock( 'cegterg:overlap' )
!
@ -412,15 +436,17 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
!
DO n = 1, nbase
!
! ... the diagonal of hc and sc must be strictly real
! ... the diagonal of hc and sc must be strictly real
!
hc(n,n) = CMPLX( REAL( hc(n,n) ), 0.D0 ,kind=DP)
sc(n,n) = CMPLX( REAL( sc(n,n) ), 0.D0 ,kind=DP)
IF( n>=nb1 ) THEN
hc(n,n) = CMPLX( REAL( hc(n,n) ), 0.D0 ,kind=DP)
sc(n,n) = CMPLX( REAL( sc(n,n) ), 0.D0 ,kind=DP)
ENDIF
!
DO m = n + 1, nbase
DO m = MAX(n+1,nb1), nbase
!
hc(m,n) = CONJG( hc(n,m) )
sc(m,n) = CONJG( sc(n,m) )
hc(n,m) = CONJG( hc(m,n) )
sc(n,m) = CONJG( sc(m,n) )
!
END DO
!
@ -467,7 +493,6 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
!
CALL start_clock( 'cegterg:last' )
!
evc = ZERO
CALL divide(inter_bgrp_comm,nbase,n_start,n_end)
my_n = n_end - n_start + 1; !write (*,*) nbase,n_start,n_end
CALL ZGEMM( 'N','N', kdim, nvec, my_n, ONE, psi(1,1,n_start), kdmx, vc(n_start,1), nvecx, &
@ -501,33 +526,28 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
!
IF ( uspp ) THEN
!
psi(:,:,nvec+1:nvec+nvec) = ZERO
CALL ZGEMM( 'N','N', kdim, nvec, my_n, ONE, spsi(1,1,n_start), kdmx, vc(n_start,1), nvecx, &
ZERO, psi(1,1,nvec+1), kdmx)
CALL mp_sum( psi(:,:,nvec+1:nvec+nvec), inter_bgrp_comm )
!
spsi(:,:,1:nvec) = psi(:,:,nvec+1:nvec+nvec)
CALL mp_sum( spsi(:,:,1:nvec), inter_bgrp_comm )
!
END IF
!
psi(:,:,nvec+1:nvec+nvec) = ZERO
CALL ZGEMM( 'N','N', kdim, nvec, my_n, ONE, hpsi(1,1,n_start), kdmx, vc(n_start,1), nvecx, &
ZERO, psi(1,1,nvec+1), kdmx )
CALL mp_sum( psi(:,:,nvec+1:nvec+nvec), inter_bgrp_comm )
!
hpsi(:,:,1:nvec) = psi(:,:,nvec+1:nvec+nvec)
CALL mp_sum( hpsi(:,:,1:nvec), inter_bgrp_comm )
!
! ... refresh the reduced hamiltonian
!
nbase = nvec
!
hc(:,1:nbase) = ZERO
sc(:,1:nbase) = ZERO
vc(:,1:nbase) = ZERO
hc(1:nbase,1:nbase) = ZERO
sc(1:nbase,1:nbase) = ZERO
vc(1:nbase,1:nbase) = ZERO
!
DO n = 1, nbase
!
! hc(n,n) = REAL( e(n) )
hc(n,n) = CMPLX( e(n), 0.0_DP ,kind=DP)
!
sc(n,n) = ONE
@ -541,6 +561,8 @@ SUBROUTINE cegterg( h_psi, s_psi, uspp, g_psi, &
!
END DO iterate
!
DEALLOCATE( recv_counts )
DEALLOCATE( displs )
DEALLOCATE( conv )
DEALLOCATE( ew )
DEALLOCATE( vc )

View File

@ -51,3 +51,56 @@ SUBROUTINE divide (comm, ntodiv, startn, lastn)
RETURN
END SUBROUTINE divide
SUBROUTINE divide_all (comm, ntodiv, startn, lastn, counts, displs)
!-----------------------------------------------------------------------
!
! Given "ntodiv" objects, distribute index across a group of processors
! belonging to communicator "comm"
! Each processor gets index from "startn" to "lastn"
! If the number of processors nproc exceeds the number of objects,
! the last nproc-ntodiv processors return startn = ntodiv+1 > lastn = ntodiv
!
USE mp, ONLY : mp_size, mp_rank
IMPLICIT NONE
!
INTEGER, INTENT(in) :: comm
! communicator
INTEGER, INTENT(in) :: ntodiv
! index to be distributed
INTEGER, INTENT(out):: startn, lastn
! indices for this processor: from startn to lastn
INTEGER, INTENT(out):: counts(*), displs(*)
! indice counts and displacements of all ranks
!
INTEGER :: me_comm, nproc_comm
! identifier of current processor
! number of processors
!
INTEGER :: ndiv, rest
! number of points per processor
! number of processors having one more points
INTEGER :: ip
!
nproc_comm = mp_size(comm)
me_comm = mp_rank(comm)
!
rest = mod ( ntodiv, nproc_comm )
ndiv = int( ntodiv / nproc_comm )
!
DO ip = 1, nproc_comm
IF (rest >= ip) THEN
counts(ip) = ndiv + 1
displs(ip) = (ip-1) * (ndiv+1)
ELSE
counts(ip) = ndiv
displs(ip) = (ip-1) * ndiv + rest
ENDIF
ENDDO
! seting startn and lastn
startn = displs(me_comm+1) + 1
lastn = displs(me_comm+1) + counts(me_comm+1)
RETURN
END SUBROUTINE divide_all

View File

@ -19,11 +19,13 @@
PUBLIC :: mp_start, mp_abort, mp_stop, mp_end, &
mp_bcast, mp_sum, mp_max, mp_min, mp_rank, mp_size, &
mp_gather, mp_alltoall, mp_get, mp_put, mp_barrier, mp_report, mp_group_free, &
mp_gather, mp_allgather, mp_alltoall, mp_get, mp_put, &
mp_barrier, mp_report, mp_group_free, &
mp_root_sum, mp_comm_free, mp_comm_create, mp_comm_group, &
mp_group_create, mp_comm_split, mp_set_displs, &
mp_circular_shift_left, &
mp_get_comm_null, mp_get_comm_self, mp_count_nodes
mp_get_comm_null, mp_get_comm_self, mp_count_nodes, &
mp_type_create_column_section, mp_type_free
!
INTERFACE mp_bcast
@ -59,16 +61,25 @@
INTERFACE mp_max
MODULE PROCEDURE mp_max_i, mp_max_r, mp_max_rv, mp_max_iv
END INTERFACE
INTERFACE mp_min
MODULE PROCEDURE mp_min_i, mp_min_r, mp_min_rv, mp_min_iv
END INTERFACE
INTERFACE mp_gather
MODULE PROCEDURE mp_gather_i1, mp_gather_iv, mp_gatherv_rv, mp_gatherv_iv, &
mp_gatherv_rm, mp_gatherv_im, mp_gatherv_cv
END INTERFACE
INTERFACE mp_allgather
MODULE PROCEDURE mp_allgatherv_inplace_c1dv, mp_allgatherv_inplace_c2dv, &
mp_allgatherv_inplace_cplx_column_section
END INTERFACE
INTERFACE mp_alltoall
MODULE PROCEDURE mp_alltoall_c3d, mp_alltoall_i3d
END INTERFACE
INTERFACE mp_circular_shift_left
MODULE PROCEDURE mp_circular_shift_left_i0, &
mp_circular_shift_left_i1, &
@ -77,6 +88,10 @@
mp_circular_shift_left_c2d
END INTERFACE
INTERFACE mp_type_create_column_section
MODULE PROCEDURE mp_type_create_cplx_column_section
END INTERFACE
!------------------------------------------------------------------------------!
!
CONTAINS
@ -1962,6 +1977,89 @@
END SUBROUTINE mp_gatherv_im
!------------------------------------------------------------------------------!
!..mp_allgatherv_inplace_cv
!..Ye Luo
SUBROUTINE mp_allgatherv_inplace_cv_any(alldata, my_cplx_type, recvcount, displs, gid)
IMPLICIT NONE
COMPLEX(DP) :: alldata(*)
INTEGER, INTENT(IN) :: my_cplx_type
INTEGER, INTENT(IN) :: recvcount(:), displs(:)
INTEGER, INTENT(IN) :: gid
INTEGER :: ierr
#if defined (__MPI)
CALL MPI_ALLGATHERV( MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &
alldata, recvcount, displs, my_cplx_type, gid, ierr )
IF (ierr/=0) CALL mp_stop( 8074 )
#endif
RETURN
END SUBROUTINE mp_allgatherv_inplace_cv_any
SUBROUTINE mp_allgatherv_inplace_c1dv(alldata, recvcount, displs, gid)
IMPLICIT NONE
COMPLEX(DP) :: alldata(:), dummy
INTEGER, INTENT(IN) :: recvcount(:), displs(:)
INTEGER, INTENT(IN) :: gid
INTEGER :: ierr, npe, myid
#if defined (__MPI)
CALL mpi_comm_size( gid, npe, ierr )
IF (ierr/=0) CALL mp_stop( 8069 )
CALL mpi_comm_rank( gid, myid, ierr )
IF (ierr/=0) CALL mp_stop( 8070 )
!
IF ( SIZE( recvcount ) < npe .OR. SIZE( displs ) < npe ) CALL mp_stop( 8071 )
!
IF ( SIZE( alldata ) < displs( npe ) + recvcount( npe ) ) CALL mp_stop( 8072 )
CALL mp_allgatherv_inplace_cv_any(alldata, MPI_DOUBLE_COMPLEX, recvcount, displs, gid)
#endif
RETURN
END SUBROUTINE mp_allgatherv_inplace_c1dv
SUBROUTINE mp_allgatherv_inplace_c2dv(alldata, recvcount, displs, gid)
IMPLICIT NONE
COMPLEX(DP) :: alldata(:,:)
INTEGER, INTENT(IN) :: recvcount(:), displs(:)
INTEGER, INTENT(IN) :: gid
INTEGER :: ierr, npe, myid
#if defined (__MPI)
CALL mpi_comm_size( gid, npe, ierr )
IF (ierr/=0) CALL mp_stop( 8069 )
CALL mpi_comm_rank( gid, myid, ierr )
IF (ierr/=0) CALL mp_stop( 8070 )
!
IF ( SIZE( recvcount ) < npe .OR. SIZE( displs ) < npe ) CALL mp_stop( 8071 )
!
IF ( SIZE( alldata ) < displs( npe ) + recvcount( npe ) ) CALL mp_stop( 8072 )
CALL mp_allgatherv_inplace_cv_any(alldata, MPI_DOUBLE_COMPLEX, recvcount, displs, gid)
#endif
RETURN
END SUBROUTINE mp_allgatherv_inplace_c2dv
SUBROUTINE mp_allgatherv_inplace_cplx_column_section(alldata, my_column_type, recvcount, displs, gid)
IMPLICIT NONE
COMPLEX(DP) :: alldata
INTEGER, INTENT(IN) :: my_column_type
INTEGER, INTENT(IN) :: recvcount(:), displs(:)
INTEGER, INTENT(IN) :: gid
INTEGER :: ierr, npe, myid
#if defined (__MPI)
CALL mpi_comm_size( gid, npe, ierr )
IF (ierr/=0) CALL mp_stop( 8069 )
CALL mpi_comm_rank( gid, myid, ierr )
IF (ierr/=0) CALL mp_stop( 8070 )
!
IF ( SIZE( recvcount ) < npe .OR. SIZE( displs ) < npe ) CALL mp_stop( 8071 )
!
CALL mp_allgatherv_inplace_cv_any(alldata, my_column_type, recvcount, displs, gid)
#endif
RETURN
END SUBROUTINE mp_allgatherv_inplace_cplx_column_section
!------------------------------------------------------------------------------!
SUBROUTINE mp_set_displs( recvcount, displs, ntot, nproc )
@ -2361,6 +2459,39 @@ FUNCTION mp_get_comm_self( )
mp_get_comm_self = MPI_COMM_SELF
END FUNCTION mp_get_comm_self
SUBROUTINE mp_type_create_cplx_column_section(dummy, start, length, stride, mytype)
IMPLICIT NONE
!
COMPLEX (DP), INTENT(IN) :: dummy
INTEGER, INTENT(IN) :: start, length, stride
INTEGER, INTENT(OUT) :: mytype
!
#if defined(__MPI)
INTEGER :: ierr
!
CALL MPI_TYPE_CREATE_SUBARRAY(1, stride, length, start, MPI_ORDER_FORTRAN,&
MPI_DOUBLE_COMPLEX, mytype, ierr)
IF (ierr/=0) CALL mp_stop( 8081 )
CALL MPI_Type_commit(mytype, ierr)
IF (ierr/=0) CALL mp_stop( 8082 )
#else
mytype = 0;
#endif
!
RETURN
END SUBROUTINE mp_type_create_cplx_column_section
SUBROUTINE mp_type_free(mytype)
IMPLICIT NONE
INTEGER :: mytype, ierr
!
#if defined(__MPI)
CALL MPI_TYPE_FREE(mytype, ierr)
IF (ierr/=0) CALL mp_stop( 8083 )
#endif
!
RETURN
END SUBROUTINE mp_type_free
!------------------------------------------------------------------------------!
END MODULE mp
!------------------------------------------------------------------------------!