pw2wannier90: Refactor write_plot. Use pointer to remove reduce_unk branching.

This commit is contained in:
Jae-Mo Lihm 2022-02-26 23:47:48 +09:00
parent 4f8229ba59
commit 7e8e51595c
1 changed files with 22 additions and 70 deletions

View File

@ -4367,41 +4367,34 @@ SUBROUTINE write_plot
IMPLICIT NONE
!
INTEGER :: ik, npw, ibnd, ibnd1, ikevc, i1, j, spin
INTEGER :: ipol, nxxs
INTEGER :: nr1, nr2, nr3
!! Real space grid sizes for the wavefunction data written to file
CHARACTER*20 :: wfnname
! aam: 1/5/06: for writing smaller unk files
INTEGER :: n1by2,n2by2,n3by2,i,k,idx,pos
COMPLEX(DP),ALLOCATABLE :: evc_r(:, :), psic_small(:, :)
INTEGER ipol
!-------------------------------------------!
#if defined(__MPI)
INTEGER :: nxxs
COMPLEX(DP),ALLOCATABLE :: psic_all(:, :)
nxxs = dffts%nr1x * dffts%nr2x * dffts%nr3x
ALLOCATE(psic_all(nxxs, npol) )
#endif
INTEGER :: i,k,idx,pos
COMPLEX(DP), ALLOCATABLE :: evc_r(:, :)
COMPLEX(DP), POINTER :: psic_small(:, :)
COMPLEX(DP), ALLOCATABLE, TARGET :: psic_all(:, :)
!
CALL start_clock( 'write_unk' )
!
nxxs = dffts%nr1x * dffts%nr2x * dffts%nr3x
ALLOCATE(psic_all(nxxs, npol) )
ALLOCATE(evc_r(dffts%nnr, npol))
!
IF (reduce_unk) THEN
! TODO: Check if dffts%nr1 is divisible by 2
WRITE(stdout,'(3(a,i5))') 'nr1s =',dffts%nr1,'nr2s=',dffts%nr2,'nr3s=',dffts%nr3
n1by2=(dffts%nr1+1)/2
n2by2=(dffts%nr2+1)/2
n3by2=(dffts%nr3+1)/2
WRITE(stdout,'(3(a,i5))') 'n1by2=',n1by2,'n2by2=',n2by2,'n3by2=',n3by2
ALLOCATE(psic_small(n1by2*n2by2*n3by2, npol))
nr1 = (dffts%nr1+1)/2
nr2 = (dffts%nr2+1)/2
nr3 = (dffts%nr3+1)/2
WRITE(stdout,'(3(a,i5))') 'n1by2=', nr1, 'n2by2=', nr2, 'n3by2=', nr3
ALLOCATE(psic_small(nr1*nr2*nr3, npol))
psic_small = (0.0_DP, 0.0_DP)
nr1 = n1by2
nr2 = n2by2
nr3 = n3by2
ELSE
psic_small => psic_all
nr1 = dffts%nr1
nr2 = dffts%nr2
nr3 = dffts%nr3
@ -4450,12 +4443,16 @@ SUBROUTINE write_plot
CALL invfft('Wave', evc_r(:, ipol), dffts)
ENDDO
!
IF (reduce_unk) pos=0
#if defined(__MPI)
DO ipol = 1, npol
CALL gather_grid(dffts, evc_r(:, ipol), psic_all(:,ipol))
ENDDO
#else
psic_all(1:dffts%nnr, :) = evc_r(1:dffts%nnr, :)
#endif
!
IF (reduce_unk) THEN
pos = 0
DO k=1,dffts%nr3,2
DO j=1,dffts%nr2,2
DO i=1,dffts%nr1,2
@ -4468,61 +4465,18 @@ SUBROUTINE write_plot
ENDDO
ENDDO
ENDIF
IF (ionode) THEN
IF(wvfn_formatted) THEN
IF (reduce_unk) THEN
!
IF (ionode) THEN
IF(wvfn_formatted) THEN
DO ipol = 1, npol
WRITE(iun_plot,'(2ES20.10)') (psic_small(j, ipol), j = 1, nr1*nr2*nr3)
ENDDO
ELSE
DO ipol = 1, npol
WRITE(iun_plot,'(2ES20.10)') (psic_all(j, ipol), j = 1, nr1*nr2*nr3)
ENDDO
ENDIF
ELSE
IF (reduce_unk) THEN
DO ipol = 1, npol
WRITE(iun_plot) (psic_small(j, ipol), j = 1, nr1*nr2*nr3)
ENDDO
ELSE
DO ipol = 1, npol
WRITE(iun_plot) (psic_all(j, ipol), j = 1, nr1*nr2*nr3)
ENDDO
ENDIF
ENDIF
ENDIF
#else
IF (reduce_unk) THEN
DO k=1,dffts%nr3,2
DO j=1,dffts%nr2,2
DO i=1,dffts%nr1,2
idx = (k-1)*dffts%nr2*dffts%nr1 + (j-1)*dffts%nr1 + i
pos=pos+1
DO ipol = 1, npol
psic_small(pos,ipol) = evc_r(idx,ipol)
ENDDO
ENDDO
ENDDO
ENDDO
ENDIF
IF(wvfn_formatted) THEN
DO ipol = 1, npol
IF (reduce_unk) THEN
WRITE (iun_plot,'(2ES20.10)') (psic_small(j, ipol), j = 1, nr1*nr2*nr3)
ELSE
WRITE (iun_plot,'(2ES20.10)') (evc_r(j, ipol), j = 1, nr1*nr2*nr3)
ENDIF
ENDDO
ELSE
DO ipol = 1, npol
IF (reduce_unk) THEN
WRITE (iun_plot) (psic_small(j, ipol), j = 1, nr1*nr2*nr3)
ELSE
WRITE (iun_plot) (evc_r(j, ipol), j = 1, nr1*nr2*nr3)
ENDIF
ENDDO
ENDIF
#endif
ENDDO !ibnd
IF(ionode) CLOSE (unit=iun_plot)
@ -4533,9 +4487,7 @@ SUBROUTINE write_plot
DEALLOCATE(psic_small)
ENDIF
DEALLOCATE(evc_r)
#if defined(__MPI)
DEALLOCATE(psic_all)
#endif
!
WRITE(stdout, *) ' UNK written'
!