- ScaLAPACK context parameter added to the linear algebra descriptor (see the usage sketch below)

- The Linear Algebra Library self-testing program now performs
  a matrix diagonalization and prints out the diagonalization timing.
  This can be useful for selecting the best low-level library and/or the
  best number of processors to use for linear algebra.
  More validation, auto-testing and performance assessment will follow
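
  A minimal sketch of the updated interface: descla_init now takes the
  BLACS/ScaLAPACK context as an extra argument between the communicator and
  the include flag, and stores it in the descriptor's new cntx field, which
  the parallel diagonalization driver then uses. Module and variable names
  (descriptors, dspev_module, mp_diag, ortho_cntx, ...) are taken from the
  diff below; the ortho group and its BLACS context are assumed to be
  already initialized (e.g. by init_ortho_group), and the example subroutine
  itself is illustrative, not part of this commit.

      ! Hedged sketch: build a linear-algebra descriptor with the new
      ! ScaLAPACK context argument and hand it to the parallel
      ! diagonalization driver. The ortho group variables (np_ortho,
      ! me_ortho, ortho_comm, ortho_cntx, ortho_comm_id) are assumed
      ! to be already set up by the calling code.
      SUBROUTINE example_parallel_diag( n, a, d, s )
        USE kinds,        ONLY : DP
        USE descriptors,  ONLY : la_descriptor, descla_init
        USE dspev_module, ONLY : diagonalize_parallel
        USE mp_diag,      ONLY : np_ortho, me_ortho, ortho_comm, &
                                 ortho_cntx, ortho_comm_id
        IMPLICIT NONE
        INTEGER,  INTENT(IN)  :: n        ! global matrix size
        REAL(DP), INTENT(IN)  :: a(:,:)   ! local block of the symmetric matrix
        REAL(DP), INTENT(OUT) :: d(:)     ! eigenvalues
        REAL(DP), INTENT(OUT) :: s(:,:)   ! local block of the eigenvectors
        TYPE(la_descriptor) :: desc
        ! the BLACS context is now carried by the descriptor (desc%cntx)
        CALL descla_init( desc, n, n, np_ortho, me_ortho, ortho_comm, &
                          ortho_cntx, ortho_comm_id )
        CALL diagonalize_parallel( n, a, d, s, desc )
      END SUBROUTINE example_parallel_diag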



git-svn-id: http://qeforge.qe-forge.org/svn/q-e/trunk/espresso@12114 c92efa57-630b-4861-b058-cf58834340f0
ccavazzoni 2016-02-12 09:41:25 +00:00
parent a857349b53
commit 023d18237b
14 changed files with 435 additions and 134 deletions

View File

@ -171,7 +171,8 @@
coor_ip(1) = ipr - 1
coor_ip(2) = ipc - 1
CALL descla_init( desc_ip, descla( is )%n, descla( is )%nx, np, coor_ip, descla( is )%comm, 1 )
CALL descla_init( desc_ip, descla( is )%n, descla( is )%nx, np, coor_ip, &
descla( is )%comm, descla( is )%cntx, 1 )
nr = desc_ip%nr
nc = desc_ip%nc

View File

@ -115,7 +115,7 @@ MODULE cp_main_variables
!------------------------------------------------------------------------
!
USE mp_global, ONLY: np_ortho, me_ortho, intra_bgrp_comm, ortho_comm, &
me_bgrp, ortho_comm_id
me_bgrp, ortho_comm_id, ortho_cntx
USE mp, ONLY: mp_max, mp_min
USE descriptors, ONLY: la_descriptor, descla_init
!
@ -199,7 +199,7 @@ MODULE cp_main_variables
ALLOCATE( descla( nspin ) )
!
DO iss = 1, nspin
CALL descla_init( descla( iss ), nupdwn( iss ), nudx, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( descla( iss ), nupdwn( iss ), nudx, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
END DO
!
nrcx = MAXVAL( descla( : )%nrcx )

View File

@ -14,8 +14,8 @@
!
USE kinds, ONLY: DP
USE orthogonalize_base, ONLY: rhoset, sigset, tauset, ortho_iterate, &
ortho_alt_iterate, diagonalize_serial, &
use_parallel_diag, diagonalize_parallel
ortho_alt_iterate, use_parallel_diag
USE dspev_module, ONLY: diagonalize_serial, diagonalize_parallel
USE descriptors, ONLY: la_descriptor
USE mp_global, ONLY: nproc_bgrp, me_bgrp, intra_bgrp_comm, my_bgrp_id, inter_bgrp_comm, nbgrp
USE mp, ONLY: mp_sum, mp_bcast

View File

@ -11,7 +11,7 @@ MODULE orthogonalize_base
USE kinds
USE dspev_module, ONLY: pdspev_drv, dspev_drv
USE dspev_module, ONLY: diagonalize_serial, diagonalize_parallel
IMPLICIT NONE
@ -34,113 +34,16 @@ MODULE orthogonalize_base
PUBLIC :: ortho_iterate
PUBLIC :: ortho_alt_iterate
PUBLIC :: updatc, calphi_bgrp
PUBLIC :: mesure_diag_perf
PUBLIC :: mesure_mmul_perf
PUBLIC :: diagonalize_parallel
PUBLIC :: diagonalize_serial
PUBLIC :: mesure_diag_perf, mesure_mmul_perf
PUBLIC :: use_parallel_diag
PUBLIC :: bec_bgrp2ortho
CONTAINS
! ----------------------------------------------
SUBROUTINE diagonalize_serial( n, rhos, rhod )
IMPLICIT NONE
INTEGER, INTENT(IN) :: n
REAL(DP) :: rhos(:,:)
REAL(DP) :: rhod(:)
!
! inputs:
! n size of the eigenproblem
! rhos the symmetric matrix
! outputs:
! rhos eigenvectors
! rhod eigenvalues
!
REAL(DP), ALLOCATABLE :: aux(:)
INTEGER :: i, j, k
IF( n < 1 ) RETURN
ALLOCATE( aux( n * ( n + 1 ) / 2 ) )
! pack lower triangle of rho into aux
!
k = 0
DO j = 1, n
DO i = j, n
k = k + 1
aux( k ) = rhos( i, j )
END DO
END DO
CALL dspev_drv( 'V', 'L', n, aux, rhod, rhos, SIZE(rhos,1) )
DEALLOCATE( aux )
RETURN
END SUBROUTINE diagonalize_serial
! ----------------------------------------------
SUBROUTINE diagonalize_parallel( n, rhos, rhod, s, desc )
USE descriptors
#ifdef __SCALAPACK
USE mp_global, ONLY: ortho_cntx, ortho_comm
USE dspev_module, ONLY: pdsyevd_drv
#endif
IMPLICIT NONE
REAL(DP), INTENT(IN) :: rhos(:,:) ! input symmetric matrix
REAL(DP) :: rhod(:) ! output eigenvalues
REAL(DP) :: s(:,:) ! output eigenvectors
INTEGER, INTENT(IN) :: n ! size of the global matrix
TYPE(la_descriptor), INTENT(IN) :: desc
IF( n < 1 ) RETURN
! Matrix is distributed on the same processors group
! used for parallel matrix multiplication
!
IF( SIZE(s,1) /= SIZE(rhos,1) .OR. SIZE(s,2) /= SIZE(rhos,2) ) &
CALL errore( " diagonalize_parallel ", " inconsistent dimension for s and rhos ", 1 )
IF ( desc%active_node > 0 ) THEN
!
IF( SIZE(s,1) /= desc%nrcx ) &
CALL errore( " diagonalize_parallel ", " inconsistent dimension ", 1 )
!
! Compute local dimension of the cyclically distributed matrix
!
s = rhos
!
#ifdef __SCALAPACK
CALL pdsyevd_drv( .true. , n, desc%nrcx, s, SIZE(s,1), rhod, ortho_cntx, ortho_comm )
#else
CALL qe_pdsyevd( .true., n, desc, s, SIZE(s,1), rhod )
#endif
!
END IF
RETURN
END SUBROUTINE diagonalize_parallel
! ----------------------------------------------
SUBROUTINE mesure_diag_perf( n )
!
USE mp_global, ONLY: nproc_bgrp, me_bgrp, intra_bgrp_comm, root_bgrp
USE mp_global, ONLY: nproc_ortho, np_ortho, me_ortho, ortho_comm, ortho_comm_id
USE mp_global, ONLY: nproc_ortho, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id
USE io_global, ONLY: ionode, stdout
USE mp, ONLY: mp_sum, mp_bcast, mp_barrier
USE mp, ONLY: mp_max
@ -166,7 +69,7 @@ END SUBROUTINE diagonalize_parallel
ALLOCATE( d( n ) )
!
CALL descla_init( desc, n, n, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( desc, n, n, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
nx = 1
IF( desc%active_node > 0 ) nx = desc%nrcx
@ -275,7 +178,6 @@ END SUBROUTINE diagonalize_parallel
END SUBROUTINE set_a
END SUBROUTINE mesure_diag_perf
! ----------------------------------------------
@ -285,7 +187,7 @@ END SUBROUTINE diagonalize_parallel
USE mp_bands, ONLY: nproc_bgrp, me_bgrp, intra_bgrp_comm, &
root_bgrp
USE mp_diag, ONLY: ortho_comm, nproc_ortho, np_ortho, &
me_ortho, init_ortho_group, ortho_comm_id
me_ortho, init_ortho_group, ortho_comm_id, ortho_cntx
USE io_global, ONLY: ionode, stdout
USE mp, ONLY: mp_sum, mp_bcast, mp_barrier
USE mp, ONLY: mp_max
@ -313,7 +215,7 @@ END SUBROUTINE diagonalize_parallel
!
CALL init_ortho_group( np * np, intra_bgrp_comm )
CALL descla_init( desc, n, n, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( desc, n, n, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
nr = desc%nr
nc = desc%nc
@ -749,7 +651,7 @@ END SUBROUTINE diagonalize_parallel
coor_ip(1) = ipr - 1
coor_ip(2) = ipc - 1
CALL descla_init( desc_ip, desc%n, desc%nx, np, coor_ip, desc%comm, 1 )
CALL descla_init( desc_ip, desc%n, desc%nx, np, coor_ip, desc%comm, desc%cntx, 1 )
nr = desc_ip%nr
nc = desc_ip%nc
@ -889,7 +791,7 @@ END SUBROUTINE diagonalize_parallel
coor_ip(1) = ipr - 1
coor_ip(2) = ipc - 1
CALL descla_init( desc_ip, desc%n, desc%nx, np, coor_ip, desc%comm, 1 )
CALL descla_init( desc_ip, desc%n, desc%nx, np, coor_ip, desc%comm, desc%cntx, 1 )
nr = desc_ip%nr
nc = desc_ip%nc
@ -1027,7 +929,7 @@ END SUBROUTINE diagonalize_parallel
coor_ip(1) = ipr - 1
coor_ip(2) = ipc - 1
CALL descla_init( desc_ip, desc%n, desc%nx, np, coor_ip, desc%comm, 1 )
CALL descla_init( desc_ip, desc%n, desc%nx, np, coor_ip, desc%comm, desc%cntx, 1 )
nr = desc_ip%nr
nc = desc_ip%nc
@ -1213,7 +1115,7 @@ END SUBROUTINE diagonalize_parallel
coor_ip(1) = ipr - 1
coor_ip(2) = ipc - 1
CALL descla_init( desc_ip, desc( iss )%n, desc( iss )%nx, np, coor_ip, desc( iss )%comm, 1 )
CALL descla_init( desc_ip, desc( iss )%n, desc( iss )%nx, np, coor_ip, desc( iss )%comm, desc( iss )%cntx, 1 )
nr = desc_ip%nr
nc = desc_ip%nc

View File

@ -19,6 +19,7 @@ MODULE dspev_module
PRIVATE
PUBLIC :: pdspev_drv, dspev_drv
PUBLIC :: diagonalize_parallel, diagonalize_serial
#if defined __SCALAPACK
PUBLIC :: pdsyevd_drv
@ -759,4 +760,88 @@ CONTAINS
#endif
! ----------------------------------------------
! Simplified driver
SUBROUTINE diagonalize_parallel( n, rhos, rhod, s, desc )
USE descriptors
IMPLICIT NONE
REAL(DP), INTENT(IN) :: rhos(:,:) ! input symmetric matrix
REAL(DP) :: rhod(:) ! output eigenvalues
REAL(DP) :: s(:,:) ! output eigenvectors
INTEGER, INTENT(IN) :: n ! size of the global matrix
TYPE(la_descriptor), INTENT(IN) :: desc
IF( n < 1 ) RETURN
! Matrix is distributed on the same processors group
! used for parallel matrix multiplication
!
IF( SIZE(s,1) /= SIZE(rhos,1) .OR. SIZE(s,2) /= SIZE(rhos,2) ) &
CALL lax_error__( " diagonalize_parallel ", " inconsistent dimension for s and rhos ", 1 )
IF ( desc%active_node > 0 ) THEN
!
IF( SIZE(s,1) /= desc%nrcx ) &
CALL lax_error__( " diagonalize_parallel ", " inconsistent dimension ", 1)
!
! Compute local dimension of the cyclically distributed matrix
!
s = rhos
!
#ifdef __SCALAPACK
CALL pdsyevd_drv( .true. , n, desc%nrcx, s, SIZE(s,1), rhod, desc%cntx, desc%comm )
#else
CALL qe_pdsyevd( .true., n, desc, s, SIZE(s,1), rhod )
#endif
!
END IF
RETURN
END SUBROUTINE diagonalize_parallel
SUBROUTINE diagonalize_serial( n, rhos, rhod )
IMPLICIT NONE
INTEGER, INTENT(IN) :: n
REAL(DP) :: rhos(:,:)
REAL(DP) :: rhod(:)
!
! inputs:
! n size of the eigenproblem
! rhos the symmetric matrix
! outputs:
! rhos eigenvectors
! rhod eigenvalues
!
REAL(DP), ALLOCATABLE :: aux(:)
INTEGER :: i, j, k
IF( n < 1 ) RETURN
ALLOCATE( aux( n * ( n + 1 ) / 2 ) )
! pack lower triangle of rho into aux
!
k = 0
DO j = 1, n
DO i = j, n
k = k + 1
aux( k ) = rhos( i, j )
END DO
END DO
CALL dspev_drv( 'V', 'L', n, aux, rhod, rhos, SIZE(rhos,1) )
DEALLOCATE( aux )
RETURN
END SUBROUTINE diagonalize_serial
END MODULE dspev_module

View File

@ -34,6 +34,7 @@
INTEGER :: myr = 0 ! processor row index
INTEGER :: myc = 0 ! processor column index
INTEGER :: comm = 0 ! communicator
INTEGER :: cntx =-1 ! scalapack context
INTEGER :: mype = 0 ! processor index ( from 0 to desc( la_npr_ ) * desc( la_npc_ ) - 1 )
INTEGER :: nrl = 0 ! number of local rows, when the matrix rows are cyclically distributed across proc
INTEGER :: nrlx = 0 ! leading dimension, when the matrix is distributed by row
@ -76,14 +77,14 @@
END SUBROUTINE descla_local_dims
!
!
SUBROUTINE descla_init( descla, n, nx, np, me, comm, includeme )
SUBROUTINE descla_init( descla, n, nx, np, me, comm, cntx, includeme )
!
IMPLICIT NONE
TYPE(la_descriptor), INTENT(OUT) :: descla
INTEGER, INTENT(IN) :: n ! the size of this matrix
INTEGER, INTENT(IN) :: nx ! the max among different matrixes sharing
! this descriptor or the same data distribution
INTEGER, INTENT(IN) :: np(2), me(2), comm
INTEGER, INTENT(IN) :: np(2), me(2), comm, cntx
INTEGER, INTENT(IN) :: includeme
INTEGER :: ir, nr, ic, nc, lnode, nrcx, nrl, nrlx
INTEGER :: ip, npp
@ -101,11 +102,13 @@
#if __SCALAPACK
nrcx = ldim_block_sca( nx, np(1), 0 )
descla%cntx = cntx
#else
nrcx = ldim_block( nx, np(1), 0 )
DO ip = 1, np(1) - 1
nrcx = MAX( nrcx, ldim_block( nx, np(1), ip ) )
END DO
descla%cntx = -1
#endif
!
! find local dimensions, if appropriate

View File

@ -2653,7 +2653,7 @@ SUBROUTINE cyc2blk_redist( n, a, lda, nca, b, ldb, ncb, desc )
!
! initialize other processor descriptor
!
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, 1 )
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, desc%cntx, 1 )
IF( ip_desc%nrcx /= nb ) &
CALL lax_error__( ' cyc2blk_redist ', ' inconsistent block dim nb ', 1 )
@ -2817,7 +2817,7 @@ SUBROUTINE cyc2blk_zredist( n, a, lda, nca, b, ldb, ncb, desc )
!
! initialize other processor descriptor
!
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, 1 )
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, desc%cntx, 1 )
ip_nr = ip_desc%nr
ip_nc = ip_desc%nc
@ -3001,7 +3001,7 @@ SUBROUTINE blk2cyc_redist( n, a, lda, nca, b, ldb, ncb, desc )
!
! initialize other processor descriptor
!
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, 1 )
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, desc%cntx, 1 )
!
ip_nr = ip_desc%nr
ip_nc = ip_desc%nc
@ -3132,7 +3132,7 @@ SUBROUTINE blk2cyc_zredist( n, a, lda, nca, b, ldb, ncb, desc )
!
! initialize other processor descriptor
!
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, 1 )
CALL descla_init( ip_desc, desc%n, desc%nx, np_ortho, me_ortho, desc%comm, desc%cntx, 1 )
!
ip_nr = ip_desc%nr
ip_nc = ip_desc%nc

View File

@ -1,3 +1,312 @@
program lax_test
call qe_pdsyevd
use descriptors
use dspev_module
IMPLICIT NONE
#ifdef __MPI
include 'mpif.h'
#endif
#include "la_param.f90"
INTEGER :: mype, npes, comm, ntgs, root
LOGICAL :: iope
INTEGER :: ierr
INTEGER :: stdout
!
INTEGER :: np_ortho(2) = 1 ! size of the processor grid used in ortho
INTEGER :: me_ortho(2) = 0 ! coordinates of the processors
INTEGER :: me_ortho1 = 0 ! task id for the ortho group
INTEGER :: nproc_ortho = 1 ! size of the ortho group:
! of two neighbour processors in ortho_comm
INTEGER :: ortho_comm = 0 ! communicator for the ortho group
INTEGER :: ortho_row_comm = 0 ! communicator for the ortho row group
INTEGER :: ortho_col_comm = 0 ! communicator for the ortho col group
INTEGER :: ortho_comm_id= 0 ! id of the ortho_comm
INTEGER :: ortho_parent_comm = 0 ! parent communicator from which ortho group has been created
!
#if defined __SCALAPACK
INTEGER :: me_blacs = 0 ! BLACS processor index starting from 0
INTEGER :: np_blacs = 1 ! BLACS number of processor
#endif
!
INTEGER :: world_cntx = -1 ! BLACS context of all processor
INTEGER :: ortho_cntx = -1 ! BLACS context for ortho_comm
!
REAL(DP), ALLOCATABLE :: a(:,:)
REAL(DP), ALLOCATABLE :: s(:,:)
REAL(DP), ALLOCATABLE :: d(:)
!
REAL(DP) :: time1, time2
TYPE(la_descriptor) :: desc
INTEGER :: i, ir, ic, nx, n, nr, nc ! size of the matrix
!
#if defined(__OPENMP)
INTEGER :: PROVIDED
#endif
!
! ........
!
#ifdef __MPI
#if defined(__OPENMP)
CALL MPI_Init_thread(MPI_THREAD_FUNNELED, PROVIDED, ierr)
#else
CALL MPI_Init(ierr)
#endif
CALL mpi_comm_rank(MPI_COMM_WORLD,mype,ierr)
CALL mpi_comm_size(MPI_COMM_WORLD,npes,ierr)
comm = MPI_COMM_WORLD
ntgs = 1
root = 0
IF(mype==root) THEN
iope = .true.
ELSE
iope = .false.
ENDIF
#else
mype = 0
npes = 1
comm = 0
ntgs = 1
root = 0
iope = .true.
#endif
!
!write(*,*) 'mype = ', mype, ' npes = ', npes
!
n = 1024
call mp_start_diag()
!
CALL descla_init( desc, n, n, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
!
nx = 1
IF( desc%active_node > 0 ) nx = desc%nrcx
!
ALLOCATE( d( n ) )
ALLOCATE( s( nx, nx ) )
ALLOCATE( a( nx, nx ) )
nr = desc%nr
nc = desc%nc
ir = desc%ir
ic = desc%ic
!
! do not take the time of the first execution, it may be biased from MPI
! initialization stuff
!
CALL set_a()
!
CALL diagonalize_parallel( n, a, d, s, desc )
!
CALL set_a()
!
CALL MPI_BARRIER( MPI_COMM_WORLD, ierr)
time1 = MPI_WTIME()
!
CALL diagonalize_parallel( n, a, d, s, desc )
!
CALL MPI_BARRIER( MPI_COMM_WORLD, ierr)
time2 = MPI_WTIME()
!
IF( mype == 0 ) THEN
write(*,*) ' Matrix eigenvalues '
IF ( n <= 16 ) THEN
DO i = 1, n
write(*,*) ' D(',i,')=',d(i)
END DO
ELSE
DO i = 1, 8
write(*,*) ' D(',i,')=',d(i)
END DO
write(*,*) ' ... '
DO i = n-8, n
write(*,*) ' D(',i,')=',d(i)
END DO
END IF
write(*,*) ' Matrix size = ', n, ' Diagonalization wall time = ', time2-time1
ENDIF
#ifdef __MPI
CALL mpi_finalize(ierr)
#endif
contains
!----------------------------------------------------------------------------
SUBROUTINE mp_start_diag( )
!---------------------------------------------------------------------------
!
! ... Ortho/diag/linear algebra group initialization
!
IMPLICIT NONE
!
INTEGER :: ierr = 0
INTEGER :: color, key, nproc_try
#if defined __SCALAPACK
INTEGER, ALLOCATABLE :: blacsmap(:,:)
INTEGER :: ortho_cntx_pe
INTEGER :: nprow, npcol, myrow, mycol, i, j, k
INTEGER, EXTERNAL :: BLACS_PNUM
!
INTEGER :: nparent=1
INTEGER :: total_nproc=1
INTEGER :: total_mype=0
INTEGER :: nproc_parent=1
INTEGER :: my_parent_id=0
#endif
!
#if defined __SCALAPACK
!
CALL mpi_comm_rank( MPI_COMM_WORLD, me_blacs, ierr)
CALL mpi_comm_size( MPI_COMM_WORLD, np_blacs, ierr)
!
! define a 1D grid containing all MPI tasks of the global communicator
! NOTE: world_cntx has the MPI communicator on entry and the BLACS context
! on exit
! BLACS_GRID_INIT() will create a copy of the communicator, which can
! be
! later retrieved using CALL BLACS_GET(world_cntx, 10, comm_copy)
!
world_cntx = MPI_COMM_WORLD
CALL BLACS_GRIDINIT( world_cntx, 'Row', 1, np_blacs )
!
#endif
!
! the ortho group for parallel linear algebra is a sub-group of the pool,
! then there are as many ortho groups as pools.
!
#if defined __MPI
nproc_try = MAX( npes, 1 )
! find the square closer (but lower) to nproc_try
!
CALL grid2d_dims( 'S', nproc_try, np_ortho(1), np_ortho(2) )
!
! now, and only now, it is possible to define the number of tasks
! in the ortho group for parallel linear algebra
!
nproc_ortho = np_ortho(1) * np_ortho(2)
!
! here we choose the first "nproc_ortho" processors
!
color = 0
IF( mype < nproc_ortho ) color = 1
!
key = mype
!
! initialize the communicator for the new group by splitting the input
! communicator
!
CALL mpi_comm_split( MPI_COMM_WORLD , color, key, ortho_comm, ierr )
!
! Computes coordinates of the processors, in row-major order
!
CALL mpi_comm_rank( ortho_comm, me_ortho1, ierr)
!
IF( mype == 0 .AND. me_ortho1 /= 0 ) &
CALL lax_error__( " init_ortho_group ", " wrong root task in ortho group ", ierr )
!
if( color == 1 ) then
! this task belong to the ortho_group compute its coordinates
ortho_comm_id = 1
CALL GRID2D_COORDS( 'R', me_ortho1, np_ortho(1), np_ortho(2), me_ortho(1), me_ortho(2) )
CALL GRID2D_RANK( 'R', np_ortho(1), np_ortho(2), me_ortho(1), me_ortho(2), ierr )
IF( ierr /= me_ortho1 ) &
CALL lax_error__( " init_ortho_group ", " wrong task coordinates in ortho group ", ierr )
IF( me_ortho1 /= mype ) &
CALL lax_error__( " init_ortho_group ", " wrong rank assignment in ortho group ", ierr )
CALL mpi_comm_split( ortho_comm , me_ortho(2), me_ortho(1), ortho_col_comm, ierr )
CALL mpi_comm_split( ortho_comm , me_ortho(1), me_ortho(2), ortho_row_comm, ierr )
else
! this task does NOT belong to the ortho_group set dummy values
ortho_comm_id = 0
me_ortho(1) = me_ortho1
me_ortho(2) = me_ortho1
endif
#if defined __SCALAPACK
!
! This part is used to eliminate the image dependency from ortho groups
! SCALAPACK is now independent of whatever level of parallelization
! is present on top of pool parallelization
!
total_nproc = npes
total_mype = mype
!
ALLOCATE( blacsmap( np_ortho(1), np_ortho(2) ) )
CALL BLACS_GET( world_cntx, 10, ortho_cntx_pe ) ! retrieve communicator of world context
blacsmap = 0
nprow = np_ortho(1)
npcol = np_ortho(2)
IF( ortho_comm_id > 0 ) THEN
blacsmap( me_ortho(1) + 1, me_ortho(2) + 1 ) = BLACS_PNUM( world_cntx, 0, me_blacs )
END IF
! All MPI tasks defined in the global communicator take part in the definition of the BLACS grid
CALL MPI_ALLREDUCE( MPI_IN_PLACE, blacsmap, SIZE( blacsmap ), MPI_INTEGER, MPI_SUM, MPI_COMM_WORLD, ierr )
CALL BLACS_GRIDMAP( ortho_cntx_pe, blacsmap, nprow, nprow, npcol)
CALL BLACS_GRIDINFO( ortho_cntx_pe, nprow, npcol, myrow, mycol )
IF( ortho_comm_id > 0) THEN
IF( np_ortho(1) /= nprow ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task rows ', 1 )
IF( np_ortho(2) /= npcol ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task columns ', 1 )
IF( me_ortho(1) /= myrow ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task row ID ', 1 )
IF( me_ortho(2) /= mycol ) &
CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task columns ID ', 1 )
ortho_cntx = ortho_cntx_pe
END IF
DEALLOCATE( blacsmap )
#endif
#else
ortho_comm_id = 1
#endif
RETURN
END SUBROUTINE
SUBROUTINE set_a()
INTEGER :: i, j, ii, jj
IF( desc%active_node < 0 ) RETURN
DO j = 1, nc
DO i = 1, nr
ii = i + ir - 1
jj = j + ic - 1
IF( ii == jj ) THEN
a(i,j) = ( DBLE( n-ii+1 ) ) / DBLE( n ) + 1.0d0 / ( DBLE( ii+jj ) - 1.0d0 )
ELSE
a(i,j) = 1.0d0 / ( DBLE( ii+jj ) - 1.0d0 )
END IF
END DO
END DO
RETURN
END SUBROUTINE set_a
end program lax_test

View File

@ -42,9 +42,10 @@ MODULE mp_diag
#if defined __SCALAPACK
INTEGER :: me_blacs = 0 ! BLACS processor index starting from 0
INTEGER :: np_blacs = 1 ! BLACS number of processor
#endif
!
INTEGER :: world_cntx = -1 ! BLACS context of all processor
INTEGER :: ortho_cntx = -1 ! BLACS context for ortho_comm
#endif
!
CONTAINS
!

View File

@ -1875,7 +1875,7 @@ SUBROUTINE pprojwave( filproj, lsym, lwrite_ovp, lbinary )
USE mp_global, ONLY : npool, me_pool, root_pool, &
intra_pool_comm, me_image, &
ortho_comm, np_ortho, me_ortho, ortho_comm_id, &
leg_ortho
leg_ortho, ortho_cntx
USE wavefunctions_module, ONLY: evc
USE parallel_toolkit, ONLY : zsqmred, zsqmher, zsqmdst, zsqmcll, dsqmsym
USE zhpev_module, ONLY : pzhpev_drv, zhpev_drv
@ -2426,7 +2426,7 @@ CONTAINS
INTEGER :: i, j, rank
INTEGER :: coor_ip( 2 )
!
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
!
nx = desc%nrcx
!
@ -2434,7 +2434,7 @@ CONTAINS
DO i = 0, desc%npr - 1
coor_ip( 1 ) = i
coor_ip( 2 ) = j
CALL descla_init( desc_ip(i+1,j+1), desc%n, desc%nx, np_ortho, coor_ip, ortho_comm, 1 )
CALL descla_init( desc_ip(i+1,j+1), desc%n, desc%nx, np_ortho, coor_ip, ortho_comm, ortho_cntx, 1 )
CALL GRID2D_RANK( 'R', desc%npr, desc%npc, i, j, rank )
rank_ip( i+1, j+1 ) = rank * leg_ortho
ENDDO

View File

@ -485,7 +485,7 @@ SUBROUTINE pcegterg( npw, npwx, nvec, nvecx, npol, evc, ethr, &
USE io_global, ONLY : stdout
USE mp_bands, ONLY : intra_bgrp_comm, inter_bgrp_comm, root_bgrp, nbgrp
USE mp_diag, ONLY : ortho_comm, np_ortho, me_ortho, ortho_comm_id, leg_ortho, &
ortho_parent_comm
ortho_parent_comm, ortho_cntx
USE descriptors, ONLY : la_descriptor, descla_init , descla_local_dims
USE parallel_toolkit, ONLY : zsqmred, zsqmher, zsqmdst
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum, mp_barrier
@ -964,7 +964,7 @@ CONTAINS
INTEGER, INTENT(OUT) :: nrc_ip(:)
INTEGER :: i, j, rank
!
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
!
nx = desc%nrcx
!

View File

@ -480,7 +480,7 @@ SUBROUTINE pregterg( npw, npwx, nvec, nvecx, evc, ethr, &
USE io_global, ONLY : stdout
USE mp_bands, ONLY : intra_bgrp_comm, inter_bgrp_comm, root_bgrp, nbgrp
USE mp_diag, ONLY : ortho_comm, np_ortho, me_ortho, ortho_comm_id, leg_ortho, &
ortho_parent_comm
ortho_parent_comm, ortho_cntx
USE descriptors, ONLY : la_descriptor, descla_init, descla_local_dims
USE parallel_toolkit, ONLY : dsqmdst, dsqmcll, dsqmred, dsqmsym
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum
@ -933,7 +933,7 @@ CONTAINS
INTEGER :: i, j, rank
!
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
!
nx = desc%nrcx
!

View File

@ -125,7 +125,7 @@ SUBROUTINE protate_wfc_gamma( npwx, npw, nstart, gstart, nbnd, psi, overlap, evc
USE control_flags, ONLY : gamma_only
USE mp_bands, ONLY : intra_bgrp_comm, nbgrp
USE mp_diag, ONLY : ortho_comm, np_ortho, me_ortho, ortho_comm_id,&
leg_ortho, ortho_parent_comm
leg_ortho, ortho_parent_comm, ortho_cntx
USE descriptors, ONLY : la_descriptor, descla_init
USE parallel_toolkit, ONLY : dsqmsym
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum, mp_barrier
@ -232,7 +232,7 @@ CONTAINS
INTEGER :: i, j, rank
INTEGER :: coor_ip( 2 )
!
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
!
nx = desc%nrcx
!
@ -240,7 +240,7 @@ CONTAINS
DO i = 0, desc%npr - 1
coor_ip( 1 ) = i
coor_ip( 2 ) = j
CALL descla_init( desc_ip(i+1,j+1), desc%n, desc%nx, np_ortho, coor_ip, ortho_comm, 1 )
CALL descla_init( desc_ip(i+1,j+1), desc%n, desc%nx, np_ortho, coor_ip, ortho_comm, ortho_cntx, 1 )
CALL GRID2D_RANK( 'R', desc%npr, desc%npc, i, j, rank )
rank_ip( i+1, j+1 ) = rank * leg_ortho
END DO

View File

@ -115,7 +115,7 @@ SUBROUTINE protate_wfc_k( npwx, npw, nstart, nbnd, npol, psi, overlap, evc, e )
USE kinds, ONLY : DP
USE mp_bands, ONLY : intra_bgrp_comm, nbgrp
USE mp_diag, ONLY : ortho_comm, np_ortho, me_ortho, ortho_comm_id,&
leg_ortho, ortho_parent_comm
leg_ortho, ortho_parent_comm, ortho_cntx
USE descriptors, ONLY : descla_init , la_descriptor
USE parallel_toolkit, ONLY : zsqmher
USE mp, ONLY : mp_bcast, mp_root_sum, mp_sum, mp_barrier
@ -231,7 +231,7 @@ CONTAINS
INTEGER :: i, j, rank
INTEGER :: coor_ip( 2 )
!
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_comm_id )
CALL descla_init( desc, nsiz, nsiz, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
!
nx = desc%nrcx
!
@ -240,7 +240,7 @@ CONTAINS
coor_ip( 1 ) = i
coor_ip( 2 ) = j
CALL descla_init( desc_ip(i+1,j+1), desc%n, desc%nx, &
np_ortho, coor_ip, ortho_comm, 1 )
np_ortho, coor_ip, ortho_comm, ortho_cntx, 1 )
CALL GRID2D_RANK( 'R', desc%npr, desc%npc, i, j, rank )
rank_ip( i+1, j+1 ) = rank * leg_ortho
END DO