Clean up cuda-runtime-wrappers API.
Do not return error code, instead return created resource handles or void. Error reporting is done by the library function. Reviewed By: herhut Differential Revision: https://reviews.llvm.org/D84660
This commit is contained in:
parent
22ec861d28
commit
c64c04bbaa
|
@ -39,7 +39,7 @@ static constexpr const char *kGpuModuleLoadName = "mgpuModuleLoad";
|
|||
static constexpr const char *kGpuModuleGetFunctionName =
|
||||
"mgpuModuleGetFunction";
|
||||
static constexpr const char *kGpuLaunchKernelName = "mgpuLaunchKernel";
|
||||
static constexpr const char *kGpuGetStreamHelperName = "mgpuGetStreamHelper";
|
||||
static constexpr const char *kGpuStreamCreateName = "mgpuStreamCreate";
|
||||
static constexpr const char *kGpuStreamSynchronizeName =
|
||||
"mgpuStreamSynchronize";
|
||||
static constexpr const char *kGpuMemHostRegisterName = "mgpuMemHostRegister";
|
||||
|
@ -100,12 +100,6 @@ private:
|
|||
getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
|
||||
}
|
||||
|
||||
LLVM::LLVMType getGpuRuntimeResultType() {
|
||||
// This is declared as an enum in both CUDA and ROCm (HIP), but helpers
|
||||
// use i32.
|
||||
return getInt32Type();
|
||||
}
|
||||
|
||||
// Allocate a void pointer on the stack.
|
||||
Value allocatePointer(OpBuilder &builder, Location loc) {
|
||||
auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
|
||||
|
@ -168,27 +162,21 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::declareGpuRuntimeFunctions(
|
|||
if (!module.lookupSymbol(kGpuModuleLoadName)) {
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kGpuModuleLoadName,
|
||||
LLVM::LLVMType::getFunctionTy(
|
||||
getGpuRuntimeResultType(),
|
||||
{
|
||||
getPointerPointerType(), /* CUmodule *module */
|
||||
getPointerType() /* void *cubin */
|
||||
},
|
||||
/*isVarArg=*/false));
|
||||
LLVM::LLVMType::getFunctionTy(getPointerType(),
|
||||
{getPointerType()}, /* void *cubin */
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol(kGpuModuleGetFunctionName)) {
|
||||
// The helper uses void* instead of CUDA's opaque CUmodule and
|
||||
// CUfunction, or ROCm (HIP)'s opaque hipModule_t and hipFunction_t.
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kGpuModuleGetFunctionName,
|
||||
LLVM::LLVMType::getFunctionTy(
|
||||
getGpuRuntimeResultType(),
|
||||
{
|
||||
getPointerPointerType(), /* void **function */
|
||||
getPointerType(), /* void *module */
|
||||
getPointerType() /* char *name */
|
||||
},
|
||||
/*isVarArg=*/false));
|
||||
LLVM::LLVMType::getFunctionTy(getPointerType(),
|
||||
{
|
||||
getPointerType(), /* void *module */
|
||||
getPointerType() /* char *name */
|
||||
},
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol(kGpuLaunchKernelName)) {
|
||||
// Other than the CUDA or ROCm (HIP) api, the wrappers use uintptr_t to
|
||||
|
@ -198,7 +186,7 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::declareGpuRuntimeFunctions(
|
|||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kGpuLaunchKernelName,
|
||||
LLVM::LLVMType::getFunctionTy(
|
||||
getGpuRuntimeResultType(),
|
||||
getVoidType(),
|
||||
{
|
||||
getPointerType(), /* void* f */
|
||||
getIntPtrType(), /* intptr_t gridXDim */
|
||||
|
@ -214,18 +202,18 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::declareGpuRuntimeFunctions(
|
|||
},
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol(kGpuGetStreamHelperName)) {
|
||||
if (!module.lookupSymbol(kGpuStreamCreateName)) {
|
||||
// Helper function to get the current GPU compute stream. Uses void*
|
||||
// instead of CUDA's opaque CUstream, or ROCm (HIP)'s opaque hipStream_t.
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kGpuGetStreamHelperName,
|
||||
loc, kGpuStreamCreateName,
|
||||
LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol(kGpuStreamSynchronizeName)) {
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kGpuStreamSynchronizeName,
|
||||
LLVM::LLVMType::getFunctionTy(getGpuRuntimeResultType(),
|
||||
getPointerType() /* CUstream stream */,
|
||||
LLVM::LLVMType::getFunctionTy(getVoidType(),
|
||||
{getPointerType()}, /* void *stream */
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol(kGpuMemHostRegisterName)) {
|
||||
|
@ -365,17 +353,13 @@ Value GpuLaunchFuncToGpuRuntimeCallsPass::generateKernelNameConstant(
|
|||
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
|
||||
//
|
||||
// %0 = call %binarygetter
|
||||
// %1 = alloca sizeof(void*)
|
||||
// call %moduleLoad(%2, %1)
|
||||
// %2 = alloca sizeof(void*)
|
||||
// %3 = load %1
|
||||
// %4 = <see generateKernelNameConstant>
|
||||
// call %moduleGetFunction(%2, %3, %4)
|
||||
// %5 = call %getStreamHelper()
|
||||
// %6 = load %2
|
||||
// %7 = <see setupParamsArray>
|
||||
// call %launchKernel(%6, <launchOp operands 0..5>, 0, %5, %7, nullptr)
|
||||
// call %streamSynchronize(%5)
|
||||
// %1 = call %moduleLoad(%0)
|
||||
// %2 = <see generateKernelNameConstant>
|
||||
// %3 = call %moduleGetFunction(%1, %2)
|
||||
// %4 = call %streamCreate()
|
||||
// %5 = <see setupParamsArray>
|
||||
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
|
||||
// call %streamSynchronize(%4)
|
||||
void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
|
||||
mlir::gpu::LaunchFuncOp launchOp) {
|
||||
OpBuilder builder(launchOp);
|
||||
|
@ -405,36 +389,30 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
|
|||
|
||||
// Emit the load module call to load the module data. Error checking is done
|
||||
// in the called helper function.
|
||||
auto gpuModule = allocatePointer(builder, loc);
|
||||
auto gpuModuleLoad =
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleLoadName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getGpuRuntimeResultType()},
|
||||
builder.getSymbolRefAttr(gpuModuleLoad),
|
||||
ArrayRef<Value>{gpuModule, data});
|
||||
auto module = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
builder.getSymbolRefAttr(gpuModuleLoad), ArrayRef<Value>{data});
|
||||
// Get the function from the module. The name corresponds to the name of
|
||||
// the kernel function.
|
||||
auto gpuOwningModuleRef =
|
||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), gpuModule);
|
||||
auto kernelName = generateKernelNameConstant(
|
||||
launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, builder);
|
||||
auto gpuFunction = allocatePointer(builder, loc);
|
||||
auto gpuModuleGetFunction =
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleGetFunctionName);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getGpuRuntimeResultType()},
|
||||
builder.getSymbolRefAttr(gpuModuleGetFunction),
|
||||
ArrayRef<Value>{gpuFunction, gpuOwningModuleRef, kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
auto gpuGetStreamHelper =
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuGetStreamHelperName);
|
||||
auto gpuStream = builder.create<LLVM::CallOp>(
|
||||
auto function = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
builder.getSymbolRefAttr(gpuGetStreamHelper), ArrayRef<Value>{});
|
||||
builder.getSymbolRefAttr(gpuModuleGetFunction),
|
||||
ArrayRef<Value>{module.getResult(0), kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
auto gpuStreamCreate =
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamCreateName);
|
||||
auto stream = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
builder.getSymbolRefAttr(gpuStreamCreate), ArrayRef<Value>{});
|
||||
// Invoke the function with required arguments.
|
||||
auto gpuLaunchKernel =
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuLaunchKernelName);
|
||||
auto gpuFunctionRef =
|
||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), gpuFunction);
|
||||
auto paramsArray = setupParamsArray(launchOp, builder);
|
||||
if (!paramsArray) {
|
||||
launchOp.emitOpError() << "cannot pass given parameters to the kernel";
|
||||
|
@ -443,21 +421,21 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
|
|||
auto nullpointer =
|
||||
builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getGpuRuntimeResultType()},
|
||||
loc, ArrayRef<Type>{getVoidType()},
|
||||
builder.getSymbolRefAttr(gpuLaunchKernel),
|
||||
ArrayRef<Value>{gpuFunctionRef, launchOp.getOperand(0),
|
||||
ArrayRef<Value>{function.getResult(0), launchOp.getOperand(0),
|
||||
launchOp.getOperand(1), launchOp.getOperand(2),
|
||||
launchOp.getOperand(3), launchOp.getOperand(4),
|
||||
launchOp.getOperand(5), zero, /* sharedMemBytes */
|
||||
gpuStream.getResult(0), /* stream */
|
||||
stream.getResult(0), /* stream */
|
||||
paramsArray, /* kernel params */
|
||||
nullpointer /* extra */});
|
||||
// Sync on the stream to make it synchronous.
|
||||
auto gpuStreamSync =
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamSynchronizeName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getGpuRuntimeResultType()},
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
|
||||
builder.getSymbolRefAttr(gpuStreamSync),
|
||||
ArrayRef<Value>(gpuStream.getResult(0)));
|
||||
ArrayRef<Value>(stream.getResult(0)));
|
||||
launchOp.erase();
|
||||
}
|
||||
|
||||
|
|
|
@ -20,13 +20,11 @@ module attributes {gpu.container_module} {
|
|||
|
||||
// CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]]
|
||||
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index)
|
||||
// CHECK: %[[binary_ptr:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]]
|
||||
// CHECK: %[[binary:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]]
|
||||
// CHECK-SAME: -> !llvm<"i8*">
|
||||
// CHECK: %[[module_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
|
||||
// CHECK: llvm.call @mgpuModuleLoad(%[[module_ptr]], %[[binary_ptr]]) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
|
||||
// CHECK: %[[func_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
|
||||
// CHECK: llvm.call @mgpuModuleGetFunction(%[[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
|
||||
// CHECK: llvm.call @mgpuGetStreamHelper
|
||||
// CHECK: %[[module:.*]] = llvm.call @mgpuModuleLoad(%[[binary]]) : (!llvm<"i8*">) -> !llvm<"i8*">
|
||||
// CHECK: %[[func:.*]] = llvm.call @mgpuModuleGetFunction(%[[module]], {{.*}}) : (!llvm<"i8*">, !llvm<"i8*">) -> !llvm<"i8*">
|
||||
// CHECK: llvm.call @mgpuStreamCreate
|
||||
// CHECK: llvm.call @mgpuLaunchKernel
|
||||
// CHECK: llvm.call @mgpuStreamSynchronize
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel_module::@kernel }
|
||||
|
|
|
@ -21,54 +21,50 @@
|
|||
|
||||
#include "cuda.h"
|
||||
|
||||
namespace {
|
||||
int32_t reportErrorIfAny(CUresult result, const char *where) {
|
||||
if (result != CUDA_SUCCESS) {
|
||||
llvm::errs() << "CUDA failed with " << result << " in " << where << "\n";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // anonymous namespace
|
||||
#define CUDA_REPORT_IF_ERROR(expr) \
|
||||
[](CUresult result) { \
|
||||
if (!result) \
|
||||
return; \
|
||||
const char *name = nullptr; \
|
||||
cuGetErrorName(result, &name); \
|
||||
if (!name) \
|
||||
name = "<unknown>"; \
|
||||
llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \
|
||||
}(expr)
|
||||
|
||||
extern "C" int32_t mgpuModuleLoad(void **module, void *data) {
|
||||
int32_t err = reportErrorIfAny(
|
||||
cuModuleLoadData(reinterpret_cast<CUmodule *>(module), data),
|
||||
"ModuleLoad");
|
||||
return err;
|
||||
extern "C" CUmodule mgpuModuleLoad(void *data) {
|
||||
CUmodule module = nullptr;
|
||||
CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
|
||||
return module;
|
||||
}
|
||||
|
||||
extern "C" int32_t mgpuModuleGetFunction(void **function, void *module,
|
||||
const char *name) {
|
||||
return reportErrorIfAny(
|
||||
cuModuleGetFunction(reinterpret_cast<CUfunction *>(function),
|
||||
reinterpret_cast<CUmodule>(module), name),
|
||||
"GetFunction");
|
||||
extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
|
||||
CUfunction function = nullptr;
|
||||
CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
|
||||
return function;
|
||||
}
|
||||
|
||||
// The wrapper uses intptr_t instead of CUDA's unsigned int to match
|
||||
// the type of MLIR's index type. This avoids the need for casts in the
|
||||
// generated MLIR code.
|
||||
extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX,
|
||||
intptr_t gridY, intptr_t gridZ,
|
||||
intptr_t blockX, intptr_t blockY,
|
||||
intptr_t blockZ, int32_t smem, void *stream,
|
||||
void **params, void **extra) {
|
||||
return reportErrorIfAny(
|
||||
cuLaunchKernel(reinterpret_cast<CUfunction>(function), gridX, gridY,
|
||||
gridZ, blockX, blockY, blockZ, smem,
|
||||
reinterpret_cast<CUstream>(stream), params, extra),
|
||||
"LaunchKernel");
|
||||
extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
|
||||
intptr_t gridY, intptr_t gridZ,
|
||||
intptr_t blockX, intptr_t blockY,
|
||||
intptr_t blockZ, int32_t smem, CUstream stream,
|
||||
void **params, void **extra) {
|
||||
CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
|
||||
blockY, blockZ, smem, stream, params,
|
||||
extra));
|
||||
}
|
||||
|
||||
extern "C" void *mgpuGetStreamHelper() {
|
||||
CUstream stream;
|
||||
reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate");
|
||||
extern "C" CUstream mgpuStreamCreate() {
|
||||
CUstream stream = nullptr;
|
||||
CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
|
||||
return stream;
|
||||
}
|
||||
|
||||
extern "C" int32_t mgpuStreamSynchronize(void *stream) {
|
||||
return reportErrorIfAny(
|
||||
cuStreamSynchronize(reinterpret_cast<CUstream>(stream)), "StreamSync");
|
||||
extern "C" void mgpuStreamSynchronize(CUstream stream) {
|
||||
CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
|
||||
}
|
||||
|
||||
/// Helper functions for writing mlir example code
|
||||
|
@ -76,17 +72,16 @@ extern "C" int32_t mgpuStreamSynchronize(void *stream) {
|
|||
// Allows to register byte array with the CUDA runtime. Helpful until we have
|
||||
// transfer functions implemented.
|
||||
extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
|
||||
reportErrorIfAny(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0),
|
||||
"MemHostRegister");
|
||||
CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
|
||||
}
|
||||
|
||||
// Allows to register a MemRef with the CUDA runtime. Initializes array with
|
||||
// value. Helpful until we have transfer functions implemented.
|
||||
template <typename T>
|
||||
void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
|
||||
llvm::SmallVector<int64_t, 4> denseStrides(mem_ref.rank);
|
||||
llvm::ArrayRef<int64_t> sizes(mem_ref.sizes, mem_ref.rank);
|
||||
llvm::ArrayRef<int64_t> strides(mem_ref.strides, mem_ref.rank);
|
||||
void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
|
||||
llvm::SmallVector<int64_t, 4> denseStrides(memRef.rank);
|
||||
llvm::ArrayRef<int64_t> sizes(memRef.sizes, memRef.rank);
|
||||
llvm::ArrayRef<int64_t> strides(memRef.strides, memRef.rank);
|
||||
|
||||
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
|
||||
std::multiplies<int64_t>());
|
||||
|
@ -98,17 +93,17 @@ void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
|
|||
denseStrides.back() = 1;
|
||||
assert(strides == llvm::makeArrayRef(denseStrides));
|
||||
|
||||
auto *pointer = mem_ref.data + mem_ref.offset;
|
||||
auto *pointer = memRef.data + memRef.offset;
|
||||
std::fill_n(pointer, count, value);
|
||||
mgpuMemHostRegister(pointer, count * sizeof(T));
|
||||
}
|
||||
|
||||
extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
|
||||
UnrankedMemRefType<float> mem_ref = {rank, ptr};
|
||||
mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(mem_ref), 1.23f);
|
||||
UnrankedMemRefType<float> memRef = {rank, ptr};
|
||||
mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(memRef), 1.23f);
|
||||
}
|
||||
|
||||
extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
|
||||
UnrankedMemRefType<int32_t> mem_ref = {rank, ptr};
|
||||
mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(mem_ref), 123);
|
||||
UnrankedMemRefType<int32_t> memRef = {rank, ptr};
|
||||
mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(memRef), 123);
|
||||
}
|
||||
|
|
|
@ -21,56 +21,52 @@
|
|||
|
||||
#include "hip/hip_runtime.h"
|
||||
|
||||
namespace {
|
||||
int32_t reportErrorIfAny(hipError_t result, const char *where) {
|
||||
if (result != hipSuccess) {
|
||||
llvm::errs() << "HIP failed with " << result << " in " << where << "\n";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // anonymous namespace
|
||||
#define HIP_REPORT_IF_ERROR(expr) \
|
||||
[](hipError_t result) { \
|
||||
if (!result) \
|
||||
return; \
|
||||
const char *name = nullptr; \
|
||||
hipGetErrorName(result, &name); \
|
||||
if (!name) \
|
||||
name = "<unknown>"; \
|
||||
llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \
|
||||
}(expr)
|
||||
|
||||
extern "C" int32_t mgpuModuleLoad(void **module, void *data) {
|
||||
int32_t err = reportErrorIfAny(
|
||||
hipModuleLoadData(reinterpret_cast<hipModule_t *>(module), data),
|
||||
"ModuleLoad");
|
||||
return err;
|
||||
extern "C" hipModule_t mgpuModuleLoad(void *data) {
|
||||
hipModule_t module = nullptr;
|
||||
HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
|
||||
return module;
|
||||
}
|
||||
|
||||
extern "C" int32_t mgpuModuleGetFunction(void **function, void *module,
|
||||
const char *name) {
|
||||
return reportErrorIfAny(
|
||||
hipModuleGetFunction(reinterpret_cast<hipFunction_t *>(function),
|
||||
reinterpret_cast<hipModule_t>(module), name),
|
||||
"GetFunction");
|
||||
extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
|
||||
const char *name) {
|
||||
hipFunction_t function = nullptr;
|
||||
HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name));
|
||||
return function;
|
||||
}
|
||||
|
||||
// The wrapper uses intptr_t instead of ROCM's unsigned int to match
|
||||
// the type of MLIR's index type. This avoids the need for casts in the
|
||||
// generated MLIR code.
|
||||
extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX,
|
||||
intptr_t gridY, intptr_t gridZ,
|
||||
intptr_t blockX, intptr_t blockY,
|
||||
intptr_t blockZ, int32_t smem, void *stream,
|
||||
void **params, void **extra) {
|
||||
return reportErrorIfAny(
|
||||
hipModuleLaunchKernel(reinterpret_cast<hipFunction_t>(function), gridX,
|
||||
gridY, gridZ, blockX, blockY, blockZ, smem,
|
||||
reinterpret_cast<hipStream_t>(stream), params,
|
||||
extra),
|
||||
"LaunchKernel");
|
||||
extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
|
||||
intptr_t gridY, intptr_t gridZ,
|
||||
intptr_t blockX, intptr_t blockY,
|
||||
intptr_t blockZ, int32_t smem,
|
||||
hipStream_t stream, void **params,
|
||||
void **extra) {
|
||||
HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
|
||||
blockX, blockY, blockZ, smem,
|
||||
stream, params, extra));
|
||||
}
|
||||
|
||||
extern "C" void *mgpuGetStreamHelper() {
|
||||
hipStream_t stream;
|
||||
reportErrorIfAny(hipStreamCreate(&stream), "StreamCreate");
|
||||
extern "C" void *mgpuStreamCreate() {
|
||||
hipStream_t stream = nullptr;
|
||||
HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
|
||||
return stream;
|
||||
}
|
||||
|
||||
extern "C" int32_t mgpuStreamSynchronize(void *stream) {
|
||||
return reportErrorIfAny(
|
||||
hipStreamSynchronize(reinterpret_cast<hipStream_t>(stream)),
|
||||
"StreamSync");
|
||||
extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
|
||||
return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
|
||||
}
|
||||
|
||||
/// Helper functions for writing mlir example code
|
||||
|
@ -78,8 +74,8 @@ extern "C" int32_t mgpuStreamSynchronize(void *stream) {
|
|||
// Allows to register byte array with the ROCM runtime. Helpful until we have
|
||||
// transfer functions implemented.
|
||||
extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
|
||||
reportErrorIfAny(hipHostRegister(ptr, sizeBytes, /*flags=*/0),
|
||||
"MemHostRegister");
|
||||
HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0),
|
||||
"MemHostRegister");
|
||||
}
|
||||
|
||||
// Allows to register a MemRef with the ROCM runtime. Initializes array with
|
||||
|
@ -120,8 +116,8 @@ extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
|
|||
|
||||
template <typename T>
|
||||
void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
|
||||
reportErrorIfAny(hipSetDevice(0), "hipSetDevice");
|
||||
reportErrorIfAny(
|
||||
HIP_REPORT_IF_ERROR(hipSetDevice(0), "hipSetDevice");
|
||||
HIP_REPORT_IF_ERROR(
|
||||
hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0),
|
||||
"hipHostGetDevicePointer");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue