Respect communicator when initializing laxlib.

Ye Luo 2021-11-12 19:24:37 -06:00
parent 8f5ff19603
commit 13e00cc76f
8 changed files with 29 additions and 100 deletions
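In short, laxlib_start no longer takes a separate "world" communicator: the library derives the ranks and sizes it needs from the parent communicator it is given, so every caller drops the world_comm / my_world_comm argument. A minimal caller-side sketch of the change, using the communicator names that appear in the diffs below (intra_bgrp_comm, intra_pool_comm and inter_bgrp_comm are assumed to have been set up by mp_startup):

    ! old interface: the global world communicator was passed explicitly
    ! CALL laxlib_start ( ndiag_, world_comm, intra_bgrp_comm, do_distr_diag_inside_bgrp_ = .TRUE. )
    !
    ! new interface: only the parent communicator is passed; laxlib works entirely inside it
    CALL laxlib_start ( ndiag_, intra_bgrp_comm, do_distr_diag_inside_bgrp_ = .TRUE. )
    CALL set_mpi_comm_4_solvers( intra_pool_comm, intra_bgrp_comm, inter_bgrp_comm )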

View File

@@ -97,8 +97,7 @@ SUBROUTINE f2libcpv(lib_comm,nim,npt,npl,nta,nbn,ndg,retval,infile)
!
CALL mp_startup ( my_world_comm=lib_comm )
ndiag_ = ndg
CALL laxlib_start ( ndiag_, world_comm, intra_bgrp_comm, &
do_distr_diag_inside_bgrp_ = diag_in_band_group_)
CALL laxlib_start ( ndiag_, intra_bgrp_comm, do_distr_diag_inside_bgrp_ = diag_in_band_group_)
CALL set_mpi_comm_4_solvers( intra_pool_comm, intra_bgrp_comm, inter_bgrp_comm)
CALL environment_start ( 'CP' )
!

View File

@@ -90,7 +90,7 @@ SUBROUTINE f2libpwscf(lib_comm,nim,npt,npl,nta,nbn,ndg,retval,infile)
nband=nbn, ndiag=ndg )
CALL mp_startup ( my_world_comm=lib_comm , start_images = .true. )
ndiag_ = ndg
CALL laxlib_start( ndiag_ , lib_comm, intra_pool_comm, do_distr_diag_inside_bgrp_ = .false.)
CALL laxlib_start( ndiag_, intra_pool_comm, do_distr_diag_inside_bgrp_ = .false.)
CALL set_mpi_comm_4_solvers ( intra_pool_comm, intra_bgrp_comm, inter_bgrp_comm)
CALL environment_start ( 'PWSCF' )
!

View File

@@ -104,7 +104,7 @@ END SUBROUTINE
!----------------------------------------------------------------------------
SUBROUTINE laxlib_start_drv( ndiag_, my_world_comm, parent_comm, do_distr_diag_inside_bgrp_ )
SUBROUTINE laxlib_start_drv( ndiag_, parent_comm, do_distr_diag_inside_bgrp_ )
!
use laxlib_processors_grid
USE laxlib_parallel_include
@@ -115,47 +115,31 @@ SUBROUTINE laxlib_start_drv( ndiag_, my_world_comm, parent_comm, do_distr_diag_i
IMPLICIT NONE
!
INTEGER, INTENT(INOUT) :: ndiag_ ! (IN) input number of procs in the diag group, (OUT) actual number
INTEGER, INTENT(IN) :: my_world_comm ! parallel communicator of the "local" world
INTEGER, INTENT(IN) :: parent_comm ! parallel communicator inside which the distributed linear algebra group
! communicators are created
LOGICAL, INTENT(IN) :: do_distr_diag_inside_bgrp_ ! as its name suggests
!
INTEGER :: mpime = 0 ! the global MPI task index (used in clocks) can be set with a laxlib_rank call
!
INTEGER :: nproc_ortho_try
INTEGER :: parent_nproc ! nproc of the parent group
INTEGER :: world_nproc ! nproc of the world group
INTEGER :: my_parent_id ! id of the parent communicator
INTEGER :: nparent_comm ! number of parent communicators
INTEGER :: ierr = 0
!
IF( lax_is_initialized ) &
CALL laxlib_end_drv ( )
CALL laxlib_end_drv ( )
world_nproc = laxlib_size( my_world_comm ) ! the global number of processors in world_comm
mpime = laxlib_rank( my_world_comm ) ! set the global MPI task index (used in clocks)
parent_nproc = laxlib_size( parent_comm )! the number of processors in the current parent communicator
my_parent_id = mpime / parent_nproc ! set the index of the current parent communicator
nparent_comm = world_nproc/parent_nproc ! number of parent communicators
! initialize blacs world_cntx
call blacs_pinfo(me_blacs, np_blacs)
!
world_cntx = MPI_COMM_WORLD
CALL BLACS_GRIDINIT( world_cntx, 'Row', 1, np_blacs )
parent_nproc = laxlib_size( parent_comm ) ! the number of processors in the current parent communicator
my_parent_id = laxlib_rank( parent_comm ) ! set the index of the current parent communicator
! save input value inside the module
do_distr_diag_inside_bgrp = do_distr_diag_inside_bgrp_
!
#if defined __SCALAPACK
np_blacs = laxlib_size( my_world_comm )
me_blacs = laxlib_rank( my_world_comm )
!
! define a 1D grid containing all MPI tasks of the global communicator
! NOTE: world_cntx has the MPI communicator on entry and the BLACS context on exit
! BLACS_GRIDINIT() will create a copy of the communicator, which can be
! later retrieved using CALL BLACS_GET(world_cntx, 10, comm_copy)
!
world_cntx = my_world_comm
CALL BLACS_GRIDINIT( world_cntx, 'Row', 1, np_blacs )
!
#endif
!
IF( ndiag_ > 0 ) THEN
! command-line argument -ndiag N or -northo N set to a value N
! use the command line value ensuring that it falls in the proper range
@@ -173,7 +157,7 @@ SUBROUTINE laxlib_start_drv( ndiag_, my_world_comm, parent_comm, do_distr_diag_i
! the ortho group for parallel linear algebra is a sub-group of the pool,
! then there are as many ortho groups as pools.
!
CALL init_ortho_group ( nproc_ortho_try, my_world_comm, parent_comm, nparent_comm, my_parent_id )
CALL init_ortho_group ( nproc_ortho_try, parent_comm )
!
! set the number of processors in the diag group to the actual number used
!
@@ -185,14 +169,11 @@ SUBROUTINE laxlib_start_drv( ndiag_, my_world_comm, parent_comm, do_distr_diag_i
!
CONTAINS
SUBROUTINE init_ortho_group ( nproc_try_in, my_world_comm, comm_all, nparent_comm, my_parent_id )
SUBROUTINE init_ortho_group ( nproc_try_in, comm_all )
!
IMPLICIT NONE
INTEGER, INTENT(IN) :: nproc_try_in, comm_all
INTEGER, INTENT(IN) :: my_world_comm ! parallel communicator of the "local" world
INTEGER, INTENT(IN) :: nparent_comm
INTEGER, INTENT(IN) :: my_parent_id ! id of the parent communicator
INTEGER :: ierr, color, key, me_all, nproc_all, nproc_try
@@ -206,7 +187,6 @@ CONTAINS
#if defined __MPI
me_all = laxlib_rank( comm_all )
!
nproc_all = laxlib_size( comm_all )
!
nproc_try = MIN( nproc_try_in, nproc_all )
@@ -282,68 +262,21 @@ CONTAINS
CALL laxlib_comm_split( ortho_comm, me_ortho(2), me_ortho(1), ortho_col_comm)
CALL laxlib_comm_split( ortho_comm, me_ortho(1), me_ortho(2), ortho_row_comm)
#if defined __SCALAPACK
!
ortho_cntx = ortho_comm
! ortho_cntx is both input and output: on input it is a system context, on output a BLACS context.
! In Fortran, a system context is just an MPI communicator. To avoid any unintended behavior,
! make sure ortho_comm matches the BLACS processor grid exactly.
call BLACS_GRIDINIT(ortho_cntx, 'R', np_ortho(1), np_ortho(2))
#endif
else
ortho_comm_id = 0
me_ortho(1) = me_ortho1
me_ortho(2) = me_ortho1
endif
#if defined __SCALAPACK
!
! This part is used to eliminate the image dependency from ortho groups.
! SCALAPACK is now independent of whatever level of parallelization
! is present on top of pool parallelization.
!
ALLOCATE( ortho_cntx_pe( nparent_comm ) )
ALLOCATE( blacsmap( np_ortho(1), np_ortho(2) ) )
DO j = 1, nparent_comm
CALL BLACS_GET(world_cntx, 10, ortho_cntx_pe( j ) ) ! retrieve communicator of world context
blacsmap = 0
nprow = np_ortho(1)
npcol = np_ortho(2)
IF( ( j == ( my_parent_id + 1 ) ) .and. ( ortho_comm_id > 0 ) ) THEN
blacsmap( me_ortho(1) + 1, me_ortho(2) + 1 ) = BLACS_PNUM( world_cntx, 0, me_blacs )
END IF
! All MPI tasks defined in the global communicator take part in the definition of the BLACS grid
CALL MPI_ALLREDUCE( MPI_IN_PLACE, blacsmap, SIZE(blacsmap), MPI_INTEGER, MPI_SUM, my_world_comm, ierr )
IF( ierr /= 0 ) &
CALL lax_error__( ' init_ortho_group ', ' problem in MPI_ALLREDUCE of blacsmap ', ierr )
CALL BLACS_GRIDMAP( ortho_cntx_pe( j ), blacsmap, nprow, nprow, npcol )
CALL BLACS_GRIDINFO( ortho_cntx_pe( j ), nprow, npcol, myrow, mycol )
IF( ( j == ( my_parent_id + 1 ) ) .and. ( ortho_comm_id > 0 ) ) THEN
IF( np_ortho(1) /= nprow ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task rows ', 1 )
IF( np_ortho(2) /= npcol ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task columns ', 1 )
IF( me_ortho(1) /= myrow ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task row ID ', 1 )
IF( me_ortho(2) /= mycol ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task columns ID ', 1 )
ortho_cntx = ortho_cntx_pe( j )
END IF
END DO
DEALLOCATE( blacsmap )
DEALLOCATE( ortho_cntx_pe )
! end SCALAPACK code block
#endif
#else
ortho_comm_id = 1

View File

@@ -8,10 +8,9 @@
!
INTERFACE laxlib_start
SUBROUTINE laxlib_start_drv( ndiag_, my_world_comm, parent_comm, do_distr_diag_inside_bgrp_ )
SUBROUTINE laxlib_start_drv( ndiag_, parent_comm, do_distr_diag_inside_bgrp_ )
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: ndiag_ ! (IN) input number of procs in the diag group, (OUT) actual number
INTEGER, INTENT(IN) :: my_world_comm ! parallel communicator of the "local" world
INTEGER, INTENT(IN) :: parent_comm ! parallel communicator inside which the distributed linear algebra group
! communicators are created
LOGICAL, INTENT(IN) :: do_distr_diag_inside_bgrp_ ! as its name suggests

View File

@@ -338,7 +338,7 @@
!
!
n_diag = n
CALL laxlib_start(n_diag, mpi_comm_world, mpi_comm_world, do_distr_diag_inside_bgrp)
CALL laxlib_start(n_diag, mpi_comm_world, do_distr_diag_inside_bgrp)
CALL laxlib_getval( np_ortho = np_ortho, ortho_comm = ortho_comm, &
do_distr_diag_inside_bgrp = do_distr_diag_inside_bgrp )
!

View File

@@ -39,11 +39,11 @@ SUBROUTINE set_para_diag( nbnd, use_para_diag )
IF( negrp > 1 .OR. do_diag_in_band_group ) THEN
! one diag group per bgrp with strict hierarchy: POOL > BAND > DIAG
! if using exx groups from mp_exx, always use this diag method
CALL laxlib_start ( ndiag_, world_comm, intra_bgrp_comm, .TRUE. )
CALL laxlib_start ( ndiag_, intra_bgrp_comm, .TRUE. )
ELSE
! one diag group per pool ( individual k-point level )
! with band group and diag group both being children of POOL comm
CALL laxlib_start ( ndiag_, world_comm, intra_pool_comm, .FALSE. )
CALL laxlib_start ( ndiag_, intra_pool_comm, .FALSE. )
END IF
CALL set_mpi_comm_4_solvers( intra_pool_comm, intra_bgrp_comm, &
inter_bgrp_comm )

View File

@@ -127,8 +127,7 @@ program all_currents
!from ../PW/src/pwscf.f90
CALL mp_startup()
CALL laxlib_start(ndiag_, world_comm, intra_bgrp_comm, &
do_distr_diag_inside_bgrp_=.TRUE.)
CALL laxlib_start(ndiag_, intra_bgrp_comm, do_distr_diag_inside_bgrp_=.TRUE.)
CALL set_mpi_comm_4_solvers(intra_pool_comm, intra_bgrp_comm, &
inter_bgrp_comm)
CALL environment_start('QEHeat')

View File

@@ -59,8 +59,7 @@ PROGRAM lr_magnons_main
pol_index = 1
!
CALL mp_startup ( )
CALL laxlib_start ( ndiag_, world_comm, intra_bgrp_comm, &
do_distr_diag_inside_bgrp_ = .true. )
CALL laxlib_start ( ndiag_, intra_bgrp_comm, do_distr_diag_inside_bgrp_ = .true. )
CALL set_mpi_comm_4_solvers( intra_pool_comm, intra_bgrp_comm, &
inter_bgrp_comm )
!