[mlir][spirv] Introduce OwningSPIRVModuleRef for ownership

Similar to OwningModuleRef, OwningSPIRVModuleRef signals ownership
transfer clearly. This is useful for APIs like spirv::deserialize,
where a spirv::ModuleOp is returned by deserializing SPIR-V binary
module.

This addresses the ASAN error as reported in
https://bugs.llvm.org/show_bug.cgi?id=46272

Differential Revision: https://reviews.llvm.org/D81652
This commit is contained in:
Lei Zhang 2020-07-07 08:28:25 -04:00
parent 7bf299c8d8
commit b80508703f
7 changed files with 127 additions and 58 deletions

View File

@ -0,0 +1,29 @@
//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/OwningOpRefBase.h"
namespace mlir {
namespace spirv {
/// This class acts as an owning reference to a SPIR-V module, and will
/// automatically destroy the held module on destruction if the held module
/// is valid.
class OwningSPIRVModuleRef : public OwningOpRefBase<spirv::ModuleOp> {
public:
using OwningOpRefBase<spirv::ModuleOp>::OwningOpRefBase;
};
} // end namespace spirv
} // end namespace mlir
#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H

View File

@ -22,6 +22,7 @@ class MLIRContext;
namespace spirv {
class ModuleOp;
class OwningSPIRVModuleRef;
/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
/// reports errors to the error handler registered with the MLIR context for
@ -31,9 +32,10 @@ LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
/// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
/// in the given `context`. Returns the ModuleOp on success; otherwise, reports
/// errors to the error handler registered with `context` and returns
/// llvm::None.
Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
/// errors to the error handler registered with `context` and returns a null
/// module.
OwningSPIRVModuleRef deserialize(ArrayRef<uint32_t> binary,
MLIRContext *context);
} // end namespace spirv
} // end namespace mlir

View File

@ -13,6 +13,7 @@
#ifndef MLIR_IR_MODULE_H
#define MLIR_IR_MODULE_H
#include "mlir/IR/OwningOpRefBase.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
@ -122,40 +123,10 @@ public:
};
/// This class acts as an owning reference to a module, and will automatically
/// destroy the held module if valid.
class OwningModuleRef {
/// destroy the held module on destruction if the held module is valid.
class OwningModuleRef : public OwningOpRefBase<ModuleOp> {
public:
OwningModuleRef(std::nullptr_t = nullptr) {}
OwningModuleRef(ModuleOp module) : module(module) {}
OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
~OwningModuleRef() {
if (module)
module.erase();
}
// Assign from another module reference.
OwningModuleRef &operator=(OwningModuleRef &&other) {
if (module)
module.erase();
module = other.release();
return *this;
}
/// Allow accessing the internal module.
ModuleOp get() const { return module; }
ModuleOp operator*() const { return module; }
ModuleOp *operator->() { return &module; }
explicit operator bool() const { return module; }
/// Release the referenced module.
ModuleOp release() {
ModuleOp released;
std::swap(released, module);
return released;
}
private:
ModuleOp module;
using OwningOpRefBase<ModuleOp>::OwningOpRefBase;
};
} // end namespace mlir

View File

@ -0,0 +1,64 @@
//===- OwningOpRefBase.h - MLIR OwningOpRefBase -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a base class for owning op refs.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_OWNINGOPREFBASE_H
#define MLIR_IR_OWNINGOPREFBASE_H
#include <utility>
namespace mlir {
/// This class acts as an owning reference to an op, and will automatically
/// destroy the held op on destruction if the held op is valid.
///
/// Note that OpBuilder and related functionality should be highly preferred
/// instead, and this should only be used in situations where existing solutions
/// are not viable.
template <typename OpTy>
class OwningOpRefBase {
public:
OwningOpRefBase(std::nullptr_t = nullptr) {}
OwningOpRefBase(OpTy op) : op(op) {}
OwningOpRefBase(OwningOpRefBase &&other) : op(other.release()) {}
~OwningOpRefBase() {
if (op)
op.erase();
}
// Assign from another op reference.
OwningOpRefBase &operator=(OwningOpRefBase &&other) {
if (op)
op.erase();
op = other.release();
return *this;
}
/// Allow accessing the internal op.
OpTy get() const { return op; }
OpTy operator*() const { return op; }
OpTy *operator->() { return &op; }
explicit operator bool() const { return op; }
/// Release the referenced op.
OpTy release() {
OpTy released;
std::swap(released, op);
return released;
}
private:
OpTy op;
};
} // end namespace mlir
#endif // MLIR_IR_OWNINGOPREFBASE_H

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -2516,12 +2517,12 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
} // namespace
Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
MLIRContext *context) {
spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
MLIRContext *context) {
Deserializer deserializer(binary, context);
if (failed(deserializer.deserialize()))
return llvm::None;
return nullptr;
return deserializer.collect();
return deserializer.collect().getValueOr(nullptr);
}

View File

@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Builders.h"
@ -49,13 +50,13 @@ static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
size / sizeof(uint32_t));
auto spirvModule = spirv::deserialize(binary, context);
spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
if (!spirvModule)
return {};
OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
module->getBody()->push_front(spirvModule->getOperation());
module->getBody()->push_front(spirvModule.release());
return module;
}
@ -136,14 +137,14 @@ static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
return failure();
// Then deserialize to get back a SPIR-V module.
auto spirvModule = spirv::deserialize(binary, context);
spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
if (!spirvModule)
return failure();
// Wrap around in a new MLIR module.
OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get(
/*filename=*/"", /*line=*/0, /*column=*/0, context)));
dstModule->getBody()->push_front(spirvModule->getOperation());
dstModule->getBody()->push_front(spirvModule.release());
dstModule->print(output);
return mlir::success();

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Diagnostics.h"
@ -46,7 +47,7 @@ protected:
}
/// Performs deserialization and returns the constructed spv.module op.
Optional<spirv::ModuleOp> deserialize() {
spirv::OwningSPIRVModuleRef deserialize() {
return spirv::deserialize(binary, &context);
}
@ -130,27 +131,27 @@ protected:
//===----------------------------------------------------------------------===//
TEST_F(DeserializationTest, EmptyModuleFailure) {
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("SPIR-V binary module must have a 5-word header");
}
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
addHeader();
binary.front() = 0xdeadbeef; // Change to a wrong magic number
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("incorrect magic number");
}
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
addHeader();
EXPECT_NE(llvm::None, deserialize());
EXPECT_TRUE(deserialize());
}
TEST_F(DeserializationTest, ZeroWordCountFailure) {
addHeader();
binary.push_back(0); // OpNop with zero word count
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("word count cannot be zero");
}
@ -160,7 +161,7 @@ TEST_F(DeserializationTest, InsufficientWordFailure) {
static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
// Missing word for type <id>
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("insufficient words for the last instruction");
}
@ -172,7 +173,7 @@ TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
addHeader();
addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
@ -198,7 +199,7 @@ TEST_F(DeserializationTest, OpMemberNameSuccess) {
addInstruction(spirv::Opcode::OpMemberName, operands2);
binary.append(typeDecl.begin(), typeDecl.end());
EXPECT_NE(llvm::None, deserialize());
EXPECT_TRUE(deserialize());
}
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
@ -215,7 +216,7 @@ TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
addInstruction(spirv::Opcode::OpMemberName, operands1);
binary.append(typeDecl.begin(), typeDecl.end());
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("OpMemberName must have at least 3 operands");
}
@ -234,7 +235,7 @@ TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
addInstruction(spirv::Opcode::OpMemberName, operands);
binary.append(typeDecl.begin(), typeDecl.end());
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}
@ -249,7 +250,7 @@ TEST_F(DeserializationTest, FunctionMissingEndFailure) {
addFunction(voidType, fnType);
// Missing OpFunctionEnd
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("expected OpFunctionEnd instruction");
}
@ -261,7 +262,7 @@ TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
addFunction(voidType, fnType);
// Missing OpFunctionParameter
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("expected OpFunctionParameter instruction");
}
@ -274,7 +275,7 @@ TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
addReturn();
addFunctionEnd();
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("a basic block must start with OpLabel");
}
@ -287,6 +288,6 @@ TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
addReturn();
addFunctionEnd();
ASSERT_EQ(llvm::None, deserialize());
ASSERT_FALSE(deserialize());
expectDiagnostic("OpLabel should only have result <id>");
}