mirror of https://gitlab.com/QEF/q-e.git
Optimize hpsi_dot_v
This commit is contained in:
parent
85f6e070d9
commit
14508b0810
|
@ -634,7 +634,9 @@ SUBROUTINE pcegterg(h_psi, s_psi, uspp, g_psi, &
|
|||
ortho_parent_comm, ortho_cntx, do_distr_diag_inside_bgrp
|
||||
USE descriptors, ONLY : la_descriptor, descla_init , descla_local_dims
|
||||
USE parallel_toolkit, ONLY : zsqmred, zsqmher, zsqmdst
|
||||
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum, mp_barrier
|
||||
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum, mp_barrier, &
|
||||
mp_size, mp_type_create_row, mp_type_free, &
|
||||
mp_allgather
|
||||
!
|
||||
IMPLICIT NONE
|
||||
!
|
||||
|
@ -1239,12 +1241,15 @@ CONTAINS
|
|||
!
|
||||
INTEGER :: ipc, ipr
|
||||
INTEGER :: nr, ir, ic, notcl, root, np, ipol, ig
|
||||
INTEGER :: ortho_parent_comm_size
|
||||
COMPLEX(DP), ALLOCATABLE :: vtmp( :, : )
|
||||
COMPLEX(DP), ALLOCATABLE :: ptmp( :, :, : )
|
||||
COMPLEX(DP) :: beta
|
||||
INTEGER :: row_type
|
||||
INTEGER, ALLOCATABLE :: counts(:), displs(:)
|
||||
|
||||
ALLOCATE( vtmp( nx, nx ) )
|
||||
ALLOCATE( ptmp( npwx, npol, nx ) )
|
||||
ortho_parent_comm_size = mp_size(ortho_parent_comm)
|
||||
!
|
||||
ALLOCATE( vtmp( nvecx, nx ) )
|
||||
ALLOCATE( counts(ortho_parent_comm_size), displs(ortho_parent_comm_size) )
|
||||
|
||||
DO ipc = 1, desc%npc
|
||||
!
|
||||
|
@ -1252,59 +1257,66 @@ CONTAINS
|
|||
|
||||
notcl = notcnv_ip( ipc )
|
||||
ic = ic_notcnv( ipc )
|
||||
|
||||
beta = ZERO
|
||||
|
||||
DO ipr = 1, desc%npr
|
||||
!
|
||||
counts = 0
|
||||
!
|
||||
DO ipr = 1, desc%npr
|
||||
nr = nrc_ip( ipr )
|
||||
ir = irc_ip( ipr )
|
||||
!
|
||||
root = rank_ip( ipr, ipc )
|
||||
|
||||
counts(root+1) = nr
|
||||
IF( ipr-1 == desc%myr .AND. ipc-1 == desc%myc .AND. la_proc ) THEN
|
||||
vtmp(:,1:notcl) = vl(:,1:notcl)
|
||||
vtmp(ir:ir+nr-1,1:notcl) = vl(:,1:notcl)
|
||||
END IF
|
||||
|
||||
CALL mp_bcast( vtmp(:,1:notcl), root, ortho_parent_comm )
|
||||
ENDDO
|
||||
!
|
||||
displs(1) = 0
|
||||
!
|
||||
DO np = 1, ortho_parent_comm_size - 1
|
||||
displs(np+1) = displs(np) + counts(np)
|
||||
ENDDO
|
||||
!
|
||||
CALL mp_type_create_row(vtmp(1,1), notcl, nvecx, row_type)
|
||||
!
|
||||
CALL mp_allgather(vtmp, row_type, counts, displs, ortho_parent_comm)
|
||||
!
|
||||
CALL mp_type_free(row_type)
|
||||
!
|
||||
IF ( uspp ) THEN
|
||||
!
|
||||
CALL ZGEMM( 'N', 'N', kdim, notcl, nr, ONE, &
|
||||
spsi( 1, 1, ir ), kdmx, vtmp, nx, beta, psi(1,1,nb1+ic-1), kdmx )
|
||||
CALL ZGEMM( 'N', 'N', kdim, notcl, nbase, ONE, &
|
||||
spsi, kdmx, vtmp, nvecx, ZERO, psi(1,1,nb1+ic-1), kdmx )
|
||||
!
|
||||
ELSE
|
||||
!
|
||||
CALL ZGEMM( 'N', 'N', kdim, notcl, nr, ONE, &
|
||||
psi( 1, 1, ir ), kdmx, vtmp, nx, beta, psi(1,1,nb1+ic-1), kdmx )
|
||||
CALL ZGEMM( 'N', 'N', kdim, notcl, nbase, ONE, &
|
||||
psi, kdmx, vtmp, nvecx, ZERO, psi(1,1,nb1+ic-1), kdmx )
|
||||
!
|
||||
END IF
|
||||
!
|
||||
CALL ZGEMM( 'N', 'N', kdim, notcl, nr, ONE, &
|
||||
hpsi( 1, 1, ir ), kdmx, vtmp, nx, beta, ptmp, kdmx )
|
||||
|
||||
beta = ONE
|
||||
|
||||
END DO
|
||||
|
||||
!$omp parallel do collapse(3)
|
||||
DO np = 1, notcl
|
||||
DO ipol = 1, npol
|
||||
DO ig = 1, npwx
|
||||
!
|
||||
psi(ig,ipol,nbase+np+ic-1) = ptmp(ig,ipol,np) - ew(nbase+np+ic-1) * psi(ig,ipol,nbase+np+ic-1)
|
||||
psi(ig,ipol,nbase+np+ic-1) = ew(nbase+np+ic-1) * psi(ig,ipol,nbase+np+ic-1)
|
||||
!
|
||||
END DO
|
||||
END DO
|
||||
END DO
|
||||
!$omp end parallel do
|
||||
!
|
||||
CALL ZGEMM( 'N', 'N', kdim, notcl, nbase, ONE, &
|
||||
hpsi, kdmx, vtmp, nvecx, -ONE, psi(1,1,nb1+ic-1), kdmx )
|
||||
|
||||
!
|
||||
END IF
|
||||
!
|
||||
END DO
|
||||
|
||||
DEALLOCATE( vtmp )
|
||||
DEALLOCATE( ptmp )
|
||||
DEALLOCATE( displs )
|
||||
DEALLOCATE( counts )
|
||||
|
||||
RETURN
|
||||
END SUBROUTINE hpsi_dot_v
|
||||
|
|
|
@ -25,7 +25,8 @@
|
|||
mp_group_create, mp_comm_split, mp_set_displs, &
|
||||
mp_circular_shift_left, &
|
||||
mp_get_comm_null, mp_get_comm_self, mp_count_nodes, &
|
||||
mp_type_create_column_section, mp_type_free
|
||||
mp_type_create_column_section, mp_type_free, &
|
||||
mp_type_create_row, mp_allgather
|
||||
|
||||
!
|
||||
INTERFACE mp_bcast
|
||||
|
@ -72,6 +73,10 @@
|
|||
mp_gatherv_inplace_cplx_column_section
|
||||
END INTERFACE
|
||||
|
||||
INTERFACE mp_allgather
|
||||
MODULE PROCEDURE mp_allgatherv_inplace_cplx_row
|
||||
END INTERFACE
|
||||
|
||||
INTERFACE mp_alltoall
|
||||
MODULE PROCEDURE mp_alltoall_c3d, mp_alltoall_i3d
|
||||
END INTERFACE
|
||||
|
@ -88,6 +93,10 @@
|
|||
MODULE PROCEDURE mp_type_create_cplx_column_section
|
||||
END INTERFACE
|
||||
|
||||
INTERFACE mp_type_create_row
|
||||
MODULE PROCEDURE mp_type_create_cplx_row
|
||||
END INTERFACE
|
||||
|
||||
!------------------------------------------------------------------------------!
|
||||
!
|
||||
CONTAINS
|
||||
|
@ -2005,6 +2014,33 @@
|
|||
RETURN
|
||||
END SUBROUTINE mp_gatherv_inplace_cplx_column_section
|
||||
|
||||
!------------------------------------------------------------------------------!
|
||||
!..mp_allgatherv_inplace_cplx_row
|
||||
!..Ye Luo
|
||||
|
||||
SUBROUTINE mp_allgatherv_inplace_cplx_row(alldata, my_row_type, recvcount, displs, gid)
   !
   ! In-place all-gather of a complex matrix over communicator gid,
   ! with counts and displacements expressed in units of my_row_type.
   !
   ! alldata     : buffer passed to MPI_ALLGATHERV as the receive buffer;
   !               because MPI_IN_PLACE is used as the send buffer, each
   !               process's own contribution must already be stored in
   !               alldata at its displacement before the call.
   ! my_row_type : MPI datatype handle used as the receive type.
   ! recvcount   : number of my_row_type items contributed by each process
   !               (must have at least as many entries as processes in gid).
   ! displs      : displacement of each process's block, in my_row_type
   !               extents (same size requirement as recvcount).
   ! gid         : communicator handle.
   !
   ! Without __MPI the routine does nothing and alldata is left untouched.
   ! Any MPI failure, or a recvcount/displs array shorter than the
   ! communicator size, aborts through mp_stop.
   !
   IMPLICIT NONE
   COMPLEX(DP), INTENT(INOUT) :: alldata(:,:)
   INTEGER, INTENT(IN) :: my_row_type
   INTEGER, INTENT(IN) :: recvcount(:), displs(:)
   INTEGER, INTENT(IN) :: gid
#if defined (__MPI)
   INTEGER :: ierr, npe
   !
   CALL mpi_comm_size( gid, npe, ierr )
   IF (ierr/=0) CALL mp_stop( 8069 )
   !
   ! one count and one displacement are required for every process in gid
   IF ( SIZE( recvcount ) < npe .OR. SIZE( displs ) < npe ) CALL mp_stop( 8071 )
   !
   ! with MPI_IN_PLACE the send count and send type arguments are ignored
   CALL MPI_ALLGATHERV( MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &
                        alldata, recvcount, displs, my_row_type, gid, ierr )
   IF (ierr/=0) CALL mp_stop( 8074 )
#endif
   RETURN
END SUBROUTINE mp_allgatherv_inplace_cplx_row
|
||||
|
||||
!------------------------------------------------------------------------------!
|
||||
|
||||
SUBROUTINE mp_set_displs( recvcount, displs, ntot, nproc )
|
||||
|
@ -2426,6 +2462,36 @@ SUBROUTINE mp_type_create_cplx_column_section(dummy, start, length, stride, myty
|
|||
RETURN
|
||||
END SUBROUTINE mp_type_create_cplx_column_section
|
||||
|
||||
SUBROUTINE mp_type_create_cplx_row(dummy, ncols, stride, mytype)
   !
   ! Create and commit an MPI derived datatype made of ncols
   ! MPI_DOUBLE_COMPLEX elements, each separated by stride elements,
   ! and resize it so that its extent equals one complex element.
   ! For a column-major matrix whose leading dimension is stride this
   ! describes one matrix row, with consecutive rows one element apart.
   !
   ! dummy  : not referenced in the body; presumably present only so the
   !          generic interface can select this specific by type (confirm
   !          against the INTERFACE block).
   ! ncols  : number of elements (columns) in the row.
   ! stride : distance, in complex elements, between successive elements.
   ! mytype : on return, the committed datatype handle; the caller is
   !          responsible for releasing it. Set to 0 without __MPI.
   !
   ! Any MPI failure aborts through mp_stop.
   !
   IMPLICIT NONE
   !
   COMPLEX (DP), INTENT(IN) :: dummy
   INTEGER, INTENT(IN) :: ncols, stride
   INTEGER, INTENT(OUT) :: mytype
   !
#if defined(__MPI)
   INTEGER :: ierr
   INTEGER :: column_type
   INTEGER (KIND=MPI_ADDRESS_KIND) :: lb, sizeofcplx
   !
   ! extent of one complex element: becomes the extent of the row type
   CALL MPI_TYPE_GET_EXTENT(MPI_DOUBLE_COMPLEX, lb, sizeofcplx, ierr)
   IF (ierr/=0) CALL mp_stop( 8081 )
   CALL MPI_TYPE_VECTOR(ncols, 1, stride, MPI_DOUBLE_COMPLEX, column_type, ierr)
   IF (ierr/=0) CALL mp_stop( 8082 )
   CALL MPI_TYPE_COMMIT(column_type, ierr)
   IF (ierr/=0) CALL mp_stop( 8083 )
   ! The lower bound argument must be of kind MPI_ADDRESS_KIND; a default
   ! integer literal has the wrong size for the address-sized dummy.
   CALL MPI_TYPE_CREATE_RESIZED(column_type, 0_MPI_ADDRESS_KIND, sizeofcplx, &
                                mytype, ierr)
   IF (ierr/=0) CALL mp_stop( 8084 )
   CALL MPI_TYPE_COMMIT(mytype, ierr)
   IF (ierr/=0) CALL mp_stop( 8085 )
   ! the intermediate vector type is no longer needed once mytype exists
   CALL mp_type_free(column_type)
#else
   mytype = 0
#endif
   !
   RETURN
END SUBROUTINE mp_type_create_cplx_row
|
||||
|
||||
SUBROUTINE mp_type_free(mytype)
|
||||
IMPLICIT NONE
|
||||
INTEGER :: mytype, ierr
|
||||
|
|
Loading…
Reference in New Issue