Optimize hpsi_dot_v

This commit is contained in:
Ye Luo 2018-05-28 19:13:39 -05:00
parent 85f6e070d9
commit 14508b0810
2 changed files with 115 additions and 37 deletions

View File

@ -634,7 +634,9 @@ SUBROUTINE pcegterg(h_psi, s_psi, uspp, g_psi, &
ortho_parent_comm, ortho_cntx, do_distr_diag_inside_bgrp
USE descriptors, ONLY : la_descriptor, descla_init , descla_local_dims
USE parallel_toolkit, ONLY : zsqmred, zsqmher, zsqmdst
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum, mp_barrier
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum, mp_barrier, &
mp_size, mp_type_create_row, mp_type_free, &
mp_allgather
!
IMPLICIT NONE
!
@ -1239,72 +1241,82 @@ CONTAINS
!
INTEGER :: ipc, ipr
INTEGER :: nr, ir, ic, notcl, root, np, ipol, ig
INTEGER :: ortho_parent_comm_size
COMPLEX(DP), ALLOCATABLE :: vtmp( :, : )
COMPLEX(DP), ALLOCATABLE :: ptmp( :, :, : )
COMPLEX(DP) :: beta
INTEGER :: row_type
INTEGER, ALLOCATABLE :: counts(:), displs(:)
ALLOCATE( vtmp( nx, nx ) )
ALLOCATE( ptmp( npwx, npol, nx ) )
ortho_parent_comm_size = mp_size(ortho_parent_comm)
!
ALLOCATE( vtmp( nvecx, nx ) )
ALLOCATE( counts(ortho_parent_comm_size), displs(ortho_parent_comm_size) )
DO ipc = 1, desc%npc
!
IF( notcnv_ip( ipc ) > 0 ) THEN
notcl = notcnv_ip( ipc )
ic = ic_notcnv( ipc )
beta = ZERO
ic = ic_notcnv( ipc )
!
counts = 0
!
DO ipr = 1, desc%npr
!
nr = nrc_ip( ipr )
ir = irc_ip( ipr )
!
root = rank_ip( ipr, ipc )
counts(root+1) = nr
IF( ipr-1 == desc%myr .AND. ipc-1 == desc%myc .AND. la_proc ) THEN
vtmp(:,1:notcl) = vl(:,1:notcl)
END IF
CALL mp_bcast( vtmp(:,1:notcl), root, ortho_parent_comm )
!
IF ( uspp ) THEN
!
CALL ZGEMM( 'N', 'N', kdim, notcl, nr, ONE, &
spsi( 1, 1, ir ), kdmx, vtmp, nx, beta, psi(1,1,nb1+ic-1), kdmx )
!
ELSE
!
CALL ZGEMM( 'N', 'N', kdim, notcl, nr, ONE, &
psi( 1, 1, ir ), kdmx, vtmp, nx, beta, psi(1,1,nb1+ic-1), kdmx )
!
vtmp(ir:ir+nr-1,1:notcl) = vl(:,1:notcl)
END IF
ENDDO
!
displs(1) = 0
!
DO np = 1, ortho_parent_comm_size - 1
displs(np+1) = displs(np) + counts(np)
ENDDO
!
CALL mp_type_create_row(vtmp(1,1), notcl, nvecx, row_type)
!
CALL mp_allgather(vtmp, row_type, counts, displs, ortho_parent_comm)
!
CALL mp_type_free(row_type)
!
IF ( uspp ) THEN
!
CALL ZGEMM( 'N', 'N', kdim, notcl, nr, ONE, &
hpsi( 1, 1, ir ), kdmx, vtmp, nx, beta, ptmp, kdmx )
beta = ONE
END DO
CALL ZGEMM( 'N', 'N', kdim, notcl, nbase, ONE, &
spsi, kdmx, vtmp, nvecx, ZERO, psi(1,1,nb1+ic-1), kdmx )
!
ELSE
!
CALL ZGEMM( 'N', 'N', kdim, notcl, nbase, ONE, &
psi, kdmx, vtmp, nvecx, ZERO, psi(1,1,nb1+ic-1), kdmx )
!
END IF
!
!$omp parallel do collapse(3)
DO np = 1, notcl
DO ipol = 1, npol
DO ig = 1, npwx
!
psi(ig,ipol,nbase+np+ic-1) = ptmp(ig,ipol,np) - ew(nbase+np+ic-1) * psi(ig,ipol,nbase+np+ic-1)
psi(ig,ipol,nbase+np+ic-1) = ew(nbase+np+ic-1) * psi(ig,ipol,nbase+np+ic-1)
!
END DO
END DO
END DO
!$omp end parallel do
!
CALL ZGEMM( 'N', 'N', kdim, notcl, nbase, ONE, &
hpsi, kdmx, vtmp, nvecx, -ONE, psi(1,1,nb1+ic-1), kdmx )
!
END IF
!
END DO
DEALLOCATE( vtmp )
DEALLOCATE( ptmp )
DEALLOCATE( displs )
DEALLOCATE( counts )
RETURN
END SUBROUTINE hpsi_dot_v

View File

@ -25,7 +25,8 @@
mp_group_create, mp_comm_split, mp_set_displs, &
mp_circular_shift_left, &
mp_get_comm_null, mp_get_comm_self, mp_count_nodes, &
mp_type_create_column_section, mp_type_free
mp_type_create_column_section, mp_type_free, &
mp_type_create_row, mp_allgather
!
INTERFACE mp_bcast
@ -72,6 +73,10 @@
mp_gatherv_inplace_cplx_column_section
END INTERFACE
INTERFACE mp_allgather
MODULE PROCEDURE mp_allgatherv_inplace_cplx_row
END INTERFACE
INTERFACE mp_alltoall
MODULE PROCEDURE mp_alltoall_c3d, mp_alltoall_i3d
END INTERFACE
@ -88,6 +93,10 @@
MODULE PROCEDURE mp_type_create_cplx_column_section
END INTERFACE
INTERFACE mp_type_create_row
MODULE PROCEDURE mp_type_create_cplx_row
END INTERFACE
!------------------------------------------------------------------------------!
!
CONTAINS
@ -2005,6 +2014,33 @@
RETURN
END SUBROUTINE mp_gatherv_inplace_cplx_column_section
!------------------------------------------------------------------------------!
!..mp_allgatherv_inplace_cplx_row
!..Ye Luo
SUBROUTINE mp_allgatherv_inplace_cplx_row(alldata, my_row_type, recvcount, displs, gid)
IMPLICIT NONE
COMPLEX(DP) :: alldata(:,:)
INTEGER, INTENT(IN) :: my_row_type
INTEGER, INTENT(IN) :: recvcount(:), displs(:)
INTEGER, INTENT(IN) :: gid
INTEGER :: ierr, npe, myid
#if defined (__MPI)
CALL mpi_comm_size( gid, npe, ierr )
IF (ierr/=0) CALL mp_stop( 8069 )
CALL mpi_comm_rank( gid, myid, ierr )
IF (ierr/=0) CALL mp_stop( 8070 )
!
IF ( SIZE( recvcount ) < npe .OR. SIZE( displs ) < npe ) CALL mp_stop( 8071 )
!
CALL MPI_ALLGATHERV( MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &
alldata, recvcount, displs, my_row_type, gid, ierr )
IF (ierr/=0) CALL mp_stop( 8074 )
#endif
RETURN
END SUBROUTINE mp_allgatherv_inplace_cplx_row
!------------------------------------------------------------------------------!
SUBROUTINE mp_set_displs( recvcount, displs, ntot, nproc )
@ -2426,6 +2462,36 @@ SUBROUTINE mp_type_create_cplx_column_section(dummy, start, length, stride, myty
RETURN
END SUBROUTINE mp_type_create_cplx_column_section
SUBROUTINE mp_type_create_cplx_row(dummy, ncols, stride, mytype)
IMPLICIT NONE
!
COMPLEX (DP), INTENT(IN) :: dummy
INTEGER, INTENT(IN) :: ncols, stride
INTEGER, INTENT(OUT) :: mytype
!
#if defined(__MPI)
INTEGER :: ierr
INTEGER column_type
INTEGER (KIND=MPI_ADDRESS_KIND) lb, sizeofcplx
!
CALL MPI_TYPE_GET_EXTENT(MPI_DOUBLE_COMPLEX, lb, sizeofcplx, ierr)
IF (ierr/=0) CALL mp_stop( 8081 )
CALL MPI_TYPE_VECTOR(ncols, 1, stride, MPI_DOUBLE_COMPLEX, column_type, ierr);
IF (ierr/=0) CALL mp_stop( 8082 )
CALL MPI_TYPE_COMMIT(column_type, ierr);
IF (ierr/=0) CALL mp_stop( 8083 )
CALL MPI_TYPE_CREATE_RESIZED(column_type, 0, sizeofcplx, mytype, ierr);
IF (ierr/=0) CALL mp_stop( 8084 )
CALL MPI_TYPE_COMMIT(mytype, ierr);
IF (ierr/=0) CALL mp_stop( 8085 )
CALL mp_type_free(column_type);
#else
mytype = 0;
#endif
!
RETURN
END SUBROUTINE mp_type_create_cplx_row
SUBROUTINE mp_type_free(mytype)
IMPLICIT NONE
INTEGER :: mytype, ierr