Fix full double precision GPU complex code.

Fix cuda_inverse complex double inversion.
Remove warning in CMakeList.txt.
This commit is contained in:
Ye Luo 2017-01-21 00:32:24 -06:00
parent 6dd24932e1
commit ae92477389
2 changed files with 10 additions and 11 deletions

View File

@ -83,11 +83,6 @@ IF(QMC_CUDA)
IF(QMC_MIXED_PRECISION)
SET(CUDA_PRECISION float)
ELSE(QMC_MIXED_PRECISION)
IF(QMC_COMPLEX)
MESSAGE(" WARNING: Complex GPU code is not recommended")
MESSAGE(" for the chosen CUDA base precision.")
MESSAGE(" Calculation results will be wrong.")
ENDIF()
SET(CUDA_PRECISION double)
ENDIF(QMC_MIXED_PRECISION)
MESSAGE(" Base precision = ${OHMMS_PRECISION}")

View File

@ -334,15 +334,19 @@ cublas_inverse (cublasHandle_t handle,
int *infoArray;
callAndCheckError( cudaMalloc((void**) &infoArray, numMats * sizeof(int)), __LINE__ );
// (i) call cublas functions to do inversion
// LU decomposition
callAndCheckError( cublasZgetrfBatched( handle, N, (cuDoubleComplex**)Alist_d, rowStride, NULL, infoArray, numMats), __LINE__ );
// (i) copy all the elements of Alist to AWorklist
dim3 dimBlockConvert (CONVERT_BS);
dim3 dimGridConvert ((N*rowStride + (CONVERT_BS-1)) / CONVERT_BS, numMats);
convert_complex<cuDoubleComplex, double, cuDoubleComplex> <<< dimGridConvert, dimBlockConvert >>> ((cuDoubleComplex**)AWorklist_d, (cuDoubleComplex**)Alist_d, N*rowStride);
// (ii) call cublas to do matrix inversion
// LU decomposition
callAndCheckError( cublasZgetrfBatched( handle, N, (cuDoubleComplex**)AWorklist_d, rowStride, NULL, infoArray, numMats), __LINE__ );
// Inversion
#if (CUDA_VERSION >= 6050)
callAndCheckError( cublasZgetriBatched( handle, N, (const cuDoubleComplex**)Alist_d, rowStride, NULL, (cuDoubleComplex**)Ainvlist_d, rowStride, infoArray, numMats), __LINE__ );
callAndCheckError( cublasZgetriBatched( handle, N, (const cuDoubleComplex**)AWorklist_d, rowStride, NULL, (cuDoubleComplex**)Ainvlist_d, rowStride, infoArray, numMats), __LINE__ );
#else
callAndCheckError( cublasZgetriBatched( handle, N, (cuDoubleComplex**)Alist_d, rowStride, NULL, (cuDoubleComplex**)Ainvlist_d, rowStride, infoArray, numMats), __LINE__ );
callAndCheckError( cublasZgetriBatched( handle, N, (cuDoubleComplex**)AWorklist_d, rowStride, NULL, (cuDoubleComplex**)Ainvlist_d, rowStride, infoArray, numMats), __LINE__ );
#endif
cudaDeviceSynchronize();