Merge branch 'develop_omp5' into 'develop_omp5'

Move from non-standard target variant dispatch to standard dispatch construct for MKL calls

See merge request QEF/q-e!2206
This commit is contained in:
Ivan Carnimeo 2024-01-11 18:29:45 +00:00
commit c87bd130df
2 changed files with 71 additions and 40 deletions

View File

@ -122,25 +122,33 @@
IF (isign < 0) THEN
IF (is_inplace) THEN
!$omp target variant dispatch use_device_ptr(c)
!$omp target data use_device_addr(c)
!$omp dispatch
dfti_status = DftiComputeForward(hand(ip)%desc, c )
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
ELSE
!$omp target variant dispatch use_device_ptr(c, cout)
!$omp target data use_device_addr(c,cout)
!$omp dispatch
dfti_status = DftiComputeForward(hand(ip)%desc, c, cout )
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
ENDIF
IF(dfti_status /= 0) CALL fftx_error__(' cft_1z GPU ',&
' stopped in DftiComputeForward '// DftiErrorMessage(dfti_status), dfti_status )
ELSE IF (isign > 0) THEN
IF (is_inplace) THEN
!$omp target variant dispatch use_device_ptr(c)
!$omp target data use_device_addr(c)
!$omp dispatch
dfti_status = DftiComputeBackward(hand(ip)%desc, c)
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
ELSE
!$omp target variant dispatch use_device_ptr(c, cout)
!$omp target data use_device_addr(c,cout)
!$omp dispatch
dfti_status = DftiComputeBackward(hand(ip)%desc, c, cout )
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
ENDIF
IF(dfti_status /= 0) CALL fftx_error__(' cft_1z GPU ',&
' stopped in DftiComputeBackward '// DftiErrorMessage(dfti_status), dfti_status )
@ -220,9 +228,9 @@
!IF(dfti_status /= 0) &
! CALL fftx_error__(' cft_1z ',' stopped in DFTI_THREAD_LIMIT ', dfti_status )
!$omp target variant dispatch
!$omp dispatch
dfti_status = DftiCommitDescriptor(hand( icurrent )%desc)
!$omp end target variant dispatch
!$omp end dispatch
IF(dfti_status /= 0) CALL fftx_error__(' cft_1z ',&
' stopped in DftiCommitDescriptor '// DftiErrorMessage(dfti_status), dfti_status )
@ -318,17 +326,21 @@
IF( isign < 0 ) THEN
!
!$omp target variant dispatch use_device_ptr(r)
!$omp target data use_device_addr(r)
!$omp dispatch
dfti_status = DftiComputeForward(hand(ip)%desc, r(:))
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
IF(dfti_status /= 0) CALL fftx_error__(' cft_2xy GPU ',&
' stopped in DftiComputeForward '// DftiErrorMessage(dfti_status), dfti_status )
!
ELSE IF( isign > 0 ) THEN
!
!$omp target variant dispatch use_device_ptr(r)
!$omp target data use_device_addr(r)
!$omp dispatch
dfti_status = DftiComputeBackward(hand(ip)%desc, r(:))
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
IF(dfti_status /= 0) CALL fftx_error__(' cft_2xy GPU ',&
' stopped in DftiComputeBackward '// DftiErrorMessage(dfti_status), dfti_status )
!
@ -400,9 +412,9 @@
IF(dfti_status /= 0) CALL fftx_error__(' cft_2xy GPU',&
' stopped in DFTI_BACKWARD_SCALE '// DftiErrorMessage(dfti_status), dfti_status )
!$omp target variant dispatch
!$omp dispatch
dfti_status = DftiCommitDescriptor(hand( icurrent )%desc)
!$omp end target variant dispatch
!$omp end dispatch
IF(dfti_status /= 0) CALL fftx_error__(' cft_2xy GPU',&
' stopped in DftiCommitDescriptor '// DftiErrorMessage(dfti_status), dfti_status )
@ -480,17 +492,21 @@
IF( isign < 0 ) THEN
!
!$omp target variant dispatch use_device_ptr(f)
!$omp target data use_device_addr(f)
!$omp dispatch
dfti_status = DftiComputeForward(hand(ip)%desc, f(1:))
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
IF(dfti_status /= 0) CALL fftx_error__(' cfft3d GPU ',&
' stopped in DftiComputeForward '// DftiErrorMessage(dfti_status), dfti_status )
!
ELSE IF( isign > 0 ) THEN
!
!$omp target variant dispatch use_device_ptr(f)
!$omp target data use_device_addr(f)
!$omp dispatch
dfti_status = DftiComputeBackward(hand(ip)%desc, f(1:))
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
IF(dfti_status /= 0) CALL fftx_error__(' cfft3d GPU ',&
' stopped in DftiComputeBackward '// DftiErrorMessage(dfti_status), dfti_status )
!
@ -567,9 +583,9 @@
IF(dfti_status /= 0) CALL fftx_error__(' cfft3d GPU',&
' stopped in DFTI_BACKWARD_SCALE '// DftiErrorMessage(dfti_status), dfti_status )
!$omp target variant dispatch
!$omp dispatch
dfti_status = DftiCommitDescriptor(hand(icurrent)%desc)
!$omp end target variant dispatch
!$omp end dispatch
IF(dfti_status /= 0) CALL fftx_error__(' cfft3d GPU',&
' stopped in DftiCommitDescriptor '// DftiErrorMessage(dfti_status), dfti_status )

View File

@ -31,9 +31,11 @@ SUBROUTINE MYDGER ( M, N, ALPHA, X, INCX, Y, INCY, A, LDA )
CALL DGER ( M, N, ALPHA, X, INCX, Y, INCY, A, LDA )
#elif defined(__OPENMP_GPU)
#if defined(__ONEMKL)
!$omp target variant dispatch use_device_ptr(A, X, Y)
!$omp target data use_device_addr(A, X, Y)
!$omp dispatch
CALL DGER ( M, N, ALPHA, X, INCX, Y, INCY, A, LDA )
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
#elif defined(__ROCBLAS)
CALL rocblas_dger( M, N, ALPHA, X, INCX, Y, INCY, A, LDA )
#endif
@ -84,9 +86,11 @@ SUBROUTINE MYDGEMM( TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC
CALL cublasdgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#elif defined(__OPENMP_GPU)
#if defined(__ONEMKL)
!$omp target variant dispatch use_device_ptr(A, B, C)
!$omp target data use_device_addr(A, B, C)
!$omp dispatch
CALL dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
#elif defined(__ROCBLAS)
CALL rocblas_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#endif
@ -117,7 +121,8 @@ SUBROUTINE MYZGEMM( TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC
CALL cublaszgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#else
#if defined(__ONEMKL)
!$omp target variant dispatch use_device_ptr(A, B, C)
!$omp target data use_device_addr(A, B, C)
!$omp dispatch
#endif
#if defined(__ROCBLAS)
CALL rocblas_zgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
@ -125,7 +130,8 @@ SUBROUTINE MYZGEMM( TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC
CALL zgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#endif
#if defined(__ONEMKL)
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
#endif
#endif
@ -159,9 +165,11 @@ SUBROUTINE MYDGER2 ( M, N, ALPHA, X, INCX, Y, INCY, A, LDA, OMP_OFFLOAD )
#elif defined(__OPENMP_GPU)
#if defined(__ONEMKL)
IF (OMP_OFFLOAD) THEN
!$omp target variant dispatch use_device_ptr(A, X, Y)
!$omp target data use_device_addr(A, B, C)
!$omp dispatch
CALL DGER ( M, N, ALPHA, X, INCX, Y, INCY, A, LDA )
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
ELSE
CALL DGER ( M, N, ALPHA, X, INCX, Y, INCY, A, LDA )
ENDIF
@ -200,9 +208,11 @@ SUBROUTINE MYDGEMM2( TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LD
CALL cublasdgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#elif defined(__ONEMKL)
IF (OMP_OFFLOAD) THEN
!$omp target variant dispatch use_device_ptr(A, B, C)
!$omp target data use_device_addr(A, B, C)
!$omp dispatch
CALL dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
ELSE
CALL dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
ENDIF
@ -237,9 +247,11 @@ SUBROUTINE MYZGEMM2( TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LD
CALL cublaszgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#elif defined(__ONEMKL)
IF (OMP_OFFLOAD) THEN
!$omp target variant dispatch use_device_ptr(A, B, C)
!$omp target data use_device_addr(A, B, C)
!$omp dispatch
CALL zgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
ELSE
CALL zgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
ENDIF
@ -270,7 +282,6 @@ SUBROUTINE MYDGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#endif
END SUBROUTINE MYDGEMV
SUBROUTINE MYZGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#if defined(__CUDA)
use cudafor
@ -309,7 +320,8 @@ SUBROUTINE MYDGEMV2(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
CALL cublasdgemv(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#else
#if defined(__ONEMKL)
!$omp target variant dispatch use_device_ptr(A, X, Y)
!$omp target data use_device_addr(A, X, Y)
!$omp dispatch
#endif
#if defined(__ROCBLAS)
CALL rocblas_dgemv(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
@ -317,7 +329,8 @@ SUBROUTINE MYDGEMV2(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
CALL dgemv(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#endif
#if defined(__ONEMKL)
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
#endif
#endif
END SUBROUTINE MYDGEMV2
@ -345,7 +358,8 @@ SUBROUTINE MYZGEMV2(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
CALL cublaszgemv(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#else
#if defined(__ONEMKL)
!$omp target variant dispatch use_device_ptr(A, X, Y)
!$omp target data use_device_addr(A, X, Y)
!$omp dispatch
#endif
#if defined(__ROCBLAS)
CALL rocblas_zgemv(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
@ -353,7 +367,8 @@ SUBROUTINE MYZGEMV2(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
CALL zgemv(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#endif
#if defined(__ONEMKL)
!$omp end target variant dispatch
!$omp end dispatch
!$omp end target data
#endif
#endif
END SUBROUTINE MYZGEMV2