mirror of https://github.com/llvm/circt.git
[Support] Extend TypeConversionPattern to apply to a single op type
This commit is contained in:
parent
6f24c2c71d
commit
0f37de27f2
|
@ -15,6 +15,11 @@
|
||||||
|
|
||||||
namespace circt {
|
namespace circt {
|
||||||
|
|
||||||
|
// Performs type conversion on the given operation.
|
||||||
|
LogicalResult doTypeConversion(Operation *op, ValueRange operands,
|
||||||
|
ConversionPatternRewriter &rewriter,
|
||||||
|
const TypeConverter *typeConverter);
|
||||||
|
|
||||||
/// Generic pattern which replaces an operation by one of the same operation
|
/// Generic pattern which replaces an operation by one of the same operation
|
||||||
/// name, but with converted attributes, operands, and result types to eliminate
|
/// name, but with converted attributes, operands, and result types to eliminate
|
||||||
/// illegal types. Uses generic builders based on OperationState to make sure
|
/// illegal types. Uses generic builders based on OperationState to make sure
|
||||||
|
@ -29,7 +34,23 @@ public:
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override;
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
return doTypeConversion(op, operands, rewriter, getTypeConverter());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Specialization of the above which targets a specific operation.
|
||||||
|
template <typename OpTy>
|
||||||
|
struct TypeOpConversionPattern : public mlir::OpConversionPattern<OpTy> {
|
||||||
|
using mlir::OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename mlir::OpConversionPattern<OpTy>::OpAdaptor;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(OpTy op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
return doTypeConversion(op.getOperation(), adaptor.getOperands(), rewriter,
|
||||||
|
this->getTypeConverter());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace circt
|
} // namespace circt
|
||||||
|
|
|
@ -34,9 +34,9 @@ static hw::ModuleType convertModuleType(const TypeConverter &typeConverter,
|
||||||
return hw::ModuleType::get(type.getContext(), ports);
|
return hw::ModuleType::get(type.getContext(), ports);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult TypeConversionPattern::matchAndRewrite(
|
LogicalResult circt::doTypeConversion(Operation *op, ValueRange operands,
|
||||||
Operation *op, ArrayRef<Value> operands,
|
ConversionPatternRewriter &rewriter,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
const TypeConverter *typeConverter) {
|
||||||
// Convert the TypeAttrs.
|
// Convert the TypeAttrs.
|
||||||
llvm::SmallVector<NamedAttribute, 4> newAttrs;
|
llvm::SmallVector<NamedAttribute, 4> newAttrs;
|
||||||
newAttrs.reserve(op->getAttrs().size());
|
newAttrs.reserve(op->getAttrs().size());
|
||||||
|
@ -46,11 +46,11 @@ LogicalResult TypeConversionPattern::matchAndRewrite(
|
||||||
// TypeConvert::convertType doesn't handle function types, so we need to
|
// TypeConvert::convertType doesn't handle function types, so we need to
|
||||||
// handle them manually.
|
// handle them manually.
|
||||||
if (auto funcType = innerType.dyn_cast<FunctionType>())
|
if (auto funcType = innerType.dyn_cast<FunctionType>())
|
||||||
innerType = convertFunctionType(*getTypeConverter(), funcType);
|
innerType = convertFunctionType(*typeConverter, funcType);
|
||||||
else if (auto modType = innerType.dyn_cast<hw::ModuleType>())
|
else if (auto modType = innerType.dyn_cast<hw::ModuleType>())
|
||||||
innerType = convertModuleType(*getTypeConverter(), modType);
|
innerType = convertModuleType(*typeConverter, modType);
|
||||||
else
|
else
|
||||||
innerType = getTypeConverter()->convertType(innerType);
|
innerType = typeConverter->convertType(innerType);
|
||||||
newAttrs.emplace_back(attr.getName(), TypeAttr::get(innerType));
|
newAttrs.emplace_back(attr.getName(), TypeAttr::get(innerType));
|
||||||
} else {
|
} else {
|
||||||
newAttrs.push_back(attr);
|
newAttrs.push_back(attr);
|
||||||
|
@ -59,8 +59,7 @@ LogicalResult TypeConversionPattern::matchAndRewrite(
|
||||||
|
|
||||||
// Convert the result types.
|
// Convert the result types.
|
||||||
llvm::SmallVector<Type, 4> newResults;
|
llvm::SmallVector<Type, 4> newResults;
|
||||||
if (failed(
|
if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
|
||||||
getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
|
|
||||||
return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed");
|
return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed");
|
||||||
|
|
||||||
// Build the state for the edited clone.
|
// Build the state for the edited clone.
|
||||||
|
@ -88,11 +87,11 @@ LogicalResult TypeConversionPattern::matchAndRewrite(
|
||||||
// Move the region and convert the region args.
|
// Move the region and convert the region args.
|
||||||
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
|
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
|
||||||
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
|
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
|
||||||
if (failed(getTypeConverter()->convertSignatureArgs(
|
if (failed(typeConverter->convertSignatureArgs(
|
||||||
newRegion->getArgumentTypes(), result)))
|
newRegion->getArgumentTypes(), result)))
|
||||||
return rewriter.notifyMatchFailure(op->getLoc(),
|
return rewriter.notifyMatchFailure(op->getLoc(),
|
||||||
"type conversion failed");
|
"type conversion failed");
|
||||||
rewriter.applySignatureConversion(newRegion, result, getTypeConverter());
|
rewriter.applySignatureConversion(newRegion, result, typeConverter);
|
||||||
|
|
||||||
// Apply the argument locations.
|
// Apply the argument locations.
|
||||||
for (auto [arg, loc] : llvm::zip(newRegion->getArguments(), argLocs))
|
for (auto [arg, loc] : llvm::zip(newRegion->getArguments(), argLocs))
|
||||||
|
|
Loading…
Reference in New Issue