[mlir][spirv] Introduce OwningSPIRVModuleRef for ownership
Similar to OwningModuleRef, OwningSPIRVModuleRef signals ownership transfer clearly. This is useful for APIs like spirv::deserialize, where a spirv::ModuleOp is returned by deserializing SPIR-V binary module. This addresses the ASAN error as reported in https://bugs.llvm.org/show_bug.cgi?id=46272 Differential Revision: https://reviews.llvm.org/D81652
This commit is contained in:
parent
7bf299c8d8
commit
b80508703f
|
@ -0,0 +1,29 @@
|
|||
//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
|
||||
#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/IR/OwningOpRefBase.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
/// This class acts as an owning reference to a SPIR-V module, and will
|
||||
/// automatically destroy the held module on destruction if the held module
|
||||
/// is valid.
|
||||
class OwningSPIRVModuleRef : public OwningOpRefBase<spirv::ModuleOp> {
|
||||
public:
|
||||
using OwningOpRefBase<spirv::ModuleOp>::OwningOpRefBase;
|
||||
};
|
||||
|
||||
} // end namespace spirv
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H
|
|
@ -22,6 +22,7 @@ class MLIRContext;
|
|||
|
||||
namespace spirv {
|
||||
class ModuleOp;
|
||||
class OwningSPIRVModuleRef;
|
||||
|
||||
/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
|
||||
/// reports errors to the error handler registered with the MLIR context for
|
||||
|
@ -31,9 +32,10 @@ LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
|
|||
|
||||
/// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
|
||||
/// in the given `context`. Returns the ModuleOp on success; otherwise, reports
|
||||
/// errors to the error handler registered with `context` and returns
|
||||
/// llvm::None.
|
||||
Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
|
||||
/// errors to the error handler registered with `context` and returns a null
|
||||
/// module.
|
||||
OwningSPIRVModuleRef deserialize(ArrayRef<uint32_t> binary,
|
||||
MLIRContext *context);
|
||||
|
||||
} // end namespace spirv
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef MLIR_IR_MODULE_H
|
||||
#define MLIR_IR_MODULE_H
|
||||
|
||||
#include "mlir/IR/OwningOpRefBase.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "llvm/Support/PointerLikeTypeTraits.h"
|
||||
|
||||
|
@ -122,40 +123,10 @@ public:
|
|||
};
|
||||
|
||||
/// This class acts as an owning reference to a module, and will automatically
|
||||
/// destroy the held module if valid.
|
||||
class OwningModuleRef {
|
||||
/// destroy the held module on destruction if the held module is valid.
|
||||
class OwningModuleRef : public OwningOpRefBase<ModuleOp> {
|
||||
public:
|
||||
OwningModuleRef(std::nullptr_t = nullptr) {}
|
||||
OwningModuleRef(ModuleOp module) : module(module) {}
|
||||
OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
|
||||
~OwningModuleRef() {
|
||||
if (module)
|
||||
module.erase();
|
||||
}
|
||||
|
||||
// Assign from another module reference.
|
||||
OwningModuleRef &operator=(OwningModuleRef &&other) {
|
||||
if (module)
|
||||
module.erase();
|
||||
module = other.release();
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Allow accessing the internal module.
|
||||
ModuleOp get() const { return module; }
|
||||
ModuleOp operator*() const { return module; }
|
||||
ModuleOp *operator->() { return &module; }
|
||||
explicit operator bool() const { return module; }
|
||||
|
||||
/// Release the referenced module.
|
||||
ModuleOp release() {
|
||||
ModuleOp released;
|
||||
std::swap(released, module);
|
||||
return released;
|
||||
}
|
||||
|
||||
private:
|
||||
ModuleOp module;
|
||||
using OwningOpRefBase<ModuleOp>::OwningOpRefBase;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
//===- OwningOpRefBase.h - MLIR OwningOpRefBase -----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file provides a base class for owning op refs.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_OWNINGOPREFBASE_H
|
||||
#define MLIR_IR_OWNINGOPREFBASE_H
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// This class acts as an owning reference to an op, and will automatically
|
||||
/// destroy the held op on destruction if the held op is valid.
|
||||
///
|
||||
/// Note that OpBuilder and related functionality should be highly preferred
|
||||
/// instead, and this should only be used in situations where existing solutions
|
||||
/// are not viable.
|
||||
template <typename OpTy>
|
||||
class OwningOpRefBase {
|
||||
public:
|
||||
OwningOpRefBase(std::nullptr_t = nullptr) {}
|
||||
OwningOpRefBase(OpTy op) : op(op) {}
|
||||
OwningOpRefBase(OwningOpRefBase &&other) : op(other.release()) {}
|
||||
~OwningOpRefBase() {
|
||||
if (op)
|
||||
op.erase();
|
||||
}
|
||||
|
||||
// Assign from another op reference.
|
||||
OwningOpRefBase &operator=(OwningOpRefBase &&other) {
|
||||
if (op)
|
||||
op.erase();
|
||||
op = other.release();
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Allow accessing the internal op.
|
||||
OpTy get() const { return op; }
|
||||
OpTy operator*() const { return op; }
|
||||
OpTy *operator->() { return &op; }
|
||||
explicit operator bool() const { return op; }
|
||||
|
||||
/// Release the referenced op.
|
||||
OpTy release() {
|
||||
OpTy released;
|
||||
std::swap(released, op);
|
||||
return released;
|
||||
}
|
||||
|
||||
private:
|
||||
OpTy op;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_OWNINGOPREFBASE_H
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -2516,12 +2517,12 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
|
|||
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
|
||||
} // namespace
|
||||
|
||||
Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
|
||||
MLIRContext *context) {
|
||||
spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
|
||||
MLIRContext *context) {
|
||||
Deserializer deserializer(binary, context);
|
||||
|
||||
if (failed(deserializer.deserialize()))
|
||||
return llvm::None;
|
||||
return nullptr;
|
||||
|
||||
return deserializer.collect();
|
||||
return deserializer.collect().getValueOr(nullptr);
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Serialization.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -49,13 +50,13 @@ static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
|
|||
auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
|
||||
size / sizeof(uint32_t));
|
||||
|
||||
auto spirvModule = spirv::deserialize(binary, context);
|
||||
spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
|
||||
if (!spirvModule)
|
||||
return {};
|
||||
|
||||
OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
|
||||
input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
|
||||
module->getBody()->push_front(spirvModule->getOperation());
|
||||
module->getBody()->push_front(spirvModule.release());
|
||||
|
||||
return module;
|
||||
}
|
||||
|
@ -136,14 +137,14 @@ static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
|
|||
return failure();
|
||||
|
||||
// Then deserialize to get back a SPIR-V module.
|
||||
auto spirvModule = spirv::deserialize(binary, context);
|
||||
spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
|
||||
if (!spirvModule)
|
||||
return failure();
|
||||
|
||||
// Wrap around in a new MLIR module.
|
||||
OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get(
|
||||
/*filename=*/"", /*line=*/0, /*column=*/0, context)));
|
||||
dstModule->getBody()->push_front(spirvModule->getOperation());
|
||||
dstModule->getBody()->push_front(spirvModule.release());
|
||||
dstModule->print(output);
|
||||
|
||||
return mlir::success();
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Serialization.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
|
@ -46,7 +47,7 @@ protected:
|
|||
}
|
||||
|
||||
/// Performs deserialization and returns the constructed spv.module op.
|
||||
Optional<spirv::ModuleOp> deserialize() {
|
||||
spirv::OwningSPIRVModuleRef deserialize() {
|
||||
return spirv::deserialize(binary, &context);
|
||||
}
|
||||
|
||||
|
@ -130,27 +131,27 @@ protected:
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TEST_F(DeserializationTest, EmptyModuleFailure) {
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("SPIR-V binary module must have a 5-word header");
|
||||
}
|
||||
|
||||
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
|
||||
addHeader();
|
||||
binary.front() = 0xdeadbeef; // Change to a wrong magic number
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("incorrect magic number");
|
||||
}
|
||||
|
||||
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
|
||||
addHeader();
|
||||
EXPECT_NE(llvm::None, deserialize());
|
||||
EXPECT_TRUE(deserialize());
|
||||
}
|
||||
|
||||
TEST_F(DeserializationTest, ZeroWordCountFailure) {
|
||||
addHeader();
|
||||
binary.push_back(0); // OpNop with zero word count
|
||||
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("word count cannot be zero");
|
||||
}
|
||||
|
||||
|
@ -160,7 +161,7 @@ TEST_F(DeserializationTest, InsufficientWordFailure) {
|
|||
static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
|
||||
// Missing word for type <id>
|
||||
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("insufficient words for the last instruction");
|
||||
}
|
||||
|
||||
|
@ -172,7 +173,7 @@ TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
|
|||
addHeader();
|
||||
addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
|
||||
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
|
||||
}
|
||||
|
||||
|
@ -198,7 +199,7 @@ TEST_F(DeserializationTest, OpMemberNameSuccess) {
|
|||
addInstruction(spirv::Opcode::OpMemberName, operands2);
|
||||
|
||||
binary.append(typeDecl.begin(), typeDecl.end());
|
||||
EXPECT_NE(llvm::None, deserialize());
|
||||
EXPECT_TRUE(deserialize());
|
||||
}
|
||||
|
||||
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
|
||||
|
@ -215,7 +216,7 @@ TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
|
|||
addInstruction(spirv::Opcode::OpMemberName, operands1);
|
||||
|
||||
binary.append(typeDecl.begin(), typeDecl.end());
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("OpMemberName must have at least 3 operands");
|
||||
}
|
||||
|
||||
|
@ -234,7 +235,7 @@ TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
|
|||
addInstruction(spirv::Opcode::OpMemberName, operands);
|
||||
|
||||
binary.append(typeDecl.begin(), typeDecl.end());
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("unexpected trailing words in OpMemberName instruction");
|
||||
}
|
||||
|
||||
|
@ -249,7 +250,7 @@ TEST_F(DeserializationTest, FunctionMissingEndFailure) {
|
|||
addFunction(voidType, fnType);
|
||||
// Missing OpFunctionEnd
|
||||
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("expected OpFunctionEnd instruction");
|
||||
}
|
||||
|
||||
|
@ -261,7 +262,7 @@ TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
|
|||
addFunction(voidType, fnType);
|
||||
// Missing OpFunctionParameter
|
||||
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("expected OpFunctionParameter instruction");
|
||||
}
|
||||
|
||||
|
@ -274,7 +275,7 @@ TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
|
|||
addReturn();
|
||||
addFunctionEnd();
|
||||
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("a basic block must start with OpLabel");
|
||||
}
|
||||
|
||||
|
@ -287,6 +288,6 @@ TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
|
|||
addReturn();
|
||||
addFunctionEnd();
|
||||
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
ASSERT_FALSE(deserialize());
|
||||
expectDiagnostic("OpLabel should only have result <id>");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue