[Support] Extend TypeConversionPattern to apply to a single op type

This commit is contained in:
Morten Borup Petersen 2023-09-28 12:52:54 +00:00
parent 6f24c2c71d
commit 0f37de27f2
2 changed files with 31 additions and 11 deletions

View File

@ -15,6 +15,11 @@
namespace circt { namespace circt {
// Performs type conversion on the given operation.
LogicalResult doTypeConversion(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
const TypeConverter *typeConverter);
/// Generic pattern which replaces an operation by one of the same operation /// Generic pattern which replaces an operation by one of the same operation
/// name, but with converted attributes, operands, and result types to eliminate /// name, but with converted attributes, operands, and result types to eliminate
/// illegal types. Uses generic builders based on OperationState to make sure /// illegal types. Uses generic builders based on OperationState to make sure
@ -29,7 +34,23 @@ public:
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override; ConversionPatternRewriter &rewriter) const override {
return doTypeConversion(op, operands, rewriter, getTypeConverter());
}
};
// Specialization of the above which targets a specific operation.
template <typename OpTy>
struct TypeOpConversionPattern : public mlir::OpConversionPattern<OpTy> {
using mlir::OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename mlir::OpConversionPattern<OpTy>::OpAdaptor;
LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return doTypeConversion(op.getOperation(), adaptor.getOperands(), rewriter,
this->getTypeConverter());
}
}; };
} // namespace circt } // namespace circt

View File

@ -34,9 +34,9 @@ static hw::ModuleType convertModuleType(const TypeConverter &typeConverter,
return hw::ModuleType::get(type.getContext(), ports); return hw::ModuleType::get(type.getContext(), ports);
} }
LogicalResult TypeConversionPattern::matchAndRewrite( LogicalResult circt::doTypeConversion(Operation *op, ValueRange operands,
Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter,
ConversionPatternRewriter &rewriter) const { const TypeConverter *typeConverter) {
// Convert the TypeAttrs. // Convert the TypeAttrs.
llvm::SmallVector<NamedAttribute, 4> newAttrs; llvm::SmallVector<NamedAttribute, 4> newAttrs;
newAttrs.reserve(op->getAttrs().size()); newAttrs.reserve(op->getAttrs().size());
@ -46,11 +46,11 @@ LogicalResult TypeConversionPattern::matchAndRewrite(
// TypeConvert::convertType doesn't handle function types, so we need to // TypeConvert::convertType doesn't handle function types, so we need to
// handle them manually. // handle them manually.
if (auto funcType = innerType.dyn_cast<FunctionType>()) if (auto funcType = innerType.dyn_cast<FunctionType>())
innerType = convertFunctionType(*getTypeConverter(), funcType); innerType = convertFunctionType(*typeConverter, funcType);
else if (auto modType = innerType.dyn_cast<hw::ModuleType>()) else if (auto modType = innerType.dyn_cast<hw::ModuleType>())
innerType = convertModuleType(*getTypeConverter(), modType); innerType = convertModuleType(*typeConverter, modType);
else else
innerType = getTypeConverter()->convertType(innerType); innerType = typeConverter->convertType(innerType);
newAttrs.emplace_back(attr.getName(), TypeAttr::get(innerType)); newAttrs.emplace_back(attr.getName(), TypeAttr::get(innerType));
} else { } else {
newAttrs.push_back(attr); newAttrs.push_back(attr);
@ -59,8 +59,7 @@ LogicalResult TypeConversionPattern::matchAndRewrite(
// Convert the result types. // Convert the result types.
llvm::SmallVector<Type, 4> newResults; llvm::SmallVector<Type, 4> newResults;
if (failed( if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed"); return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed");
// Build the state for the edited clone. // Build the state for the edited clone.
@ -88,11 +87,11 @@ LogicalResult TypeConversionPattern::matchAndRewrite(
// Move the region and convert the region args. // Move the region and convert the region args.
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
TypeConverter::SignatureConversion result(newRegion->getNumArguments()); TypeConverter::SignatureConversion result(newRegion->getNumArguments());
if (failed(getTypeConverter()->convertSignatureArgs( if (failed(typeConverter->convertSignatureArgs(
newRegion->getArgumentTypes(), result))) newRegion->getArgumentTypes(), result)))
return rewriter.notifyMatchFailure(op->getLoc(), return rewriter.notifyMatchFailure(op->getLoc(),
"type conversion failed"); "type conversion failed");
rewriter.applySignatureConversion(newRegion, result, getTypeConverter()); rewriter.applySignatureConversion(newRegion, result, typeConverter);
// Apply the argument locations. // Apply the argument locations.
for (auto [arg, loc] : llvm::zip(newRegion->getArguments(), argLocs)) for (auto [arg, loc] : llvm::zip(newRegion->getArguments(), argLocs))