[FIRRTL][CAPI] Allow constructing integers larger than 64 bits

This commit is contained in:
Asuna 2024-04-03 23:00:12 +02:00
parent 923a5ee13d
commit 8868394e32
3 changed files with 89 additions and 4 deletions

View File

@ -146,6 +146,11 @@ MLIR_CAPI_EXPORTED MlirAttribute firrtlAttrGetMemDir(MlirContext ctx,
MLIR_CAPI_EXPORTED MlirAttribute
firrtlAttrGetEventControl(MlirContext ctx, FIRRTLEventControl eventControl);
// Workaround:
// https://github.com/llvm/llvm-project/issues/84190#issuecomment-2035552035
MLIR_CAPI_EXPORTED MlirAttribute firrtlAttrGetIntegerFromString(
MlirType type, unsigned numBits, MlirStringRef str, uint8_t radix);
//===----------------------------------------------------------------------===//
// Utility API.
//===----------------------------------------------------------------------===//

View File

@ -281,6 +281,12 @@ MlirAttribute firrtlAttrGetEventControl(MlirContext ctx,
return wrap(EventControlAttr::get(unwrap(ctx), value));
}
MlirAttribute firrtlAttrGetIntegerFromString(MlirType type, unsigned numBits,
MlirStringRef str, uint8_t radix) {
auto value = APInt{numBits, unwrap(str), radix};
return wrap(IntegerAttr::get(unwrap(type), value));
}
FIRRTLValueFlow firrtlValueFoldFlow(MlirValue value, FIRRTLValueFlow flow) {
Flow flowValue;

View File

@ -19,9 +19,15 @@
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
void exportCallback(MlirStringRef message, void *userData) {
printf("%.*s", (int)message.length, message.data);
void dumpCallback(MlirStringRef message, void *userData) {
fprintf(stderr, "%.*s", (int)message.length, message.data);
}
void appendBufferCallback(MlirStringRef message, void *userData) {
char *buffer = (char *)userData;
sprintf(buffer + strlen(buffer), "%.*s", (int)message.length, message.data);
}
void testExport(MlirContext ctx) {
@ -39,7 +45,7 @@ void testExport(MlirContext ctx) {
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testFIR));
MlirLogicalResult result = mlirExportFIRRTL(module, exportCallback, NULL);
MlirLogicalResult result = mlirExportFIRRTL(module, dumpCallback, NULL);
assert(mlirLogicalResultIsSuccess(result));
// CHECK: FIRRTL version 4.0.0
@ -104,18 +110,86 @@ void testImportAnnotations(MlirContext ctx) {
firCircuit, mlirStringRefCreateFromCString("rawAnnotations"),
rawAnnotationsAttr);
mlirOperationPrint(mlirModuleGetOperation(module), exportCallback, NULL);
mlirOperationPrint(mlirModuleGetOperation(module), dumpCallback, NULL);
// clang-format off
// CHECK: firrtl.circuit "AnnoTest" attributes {rawAnnotations = [{class = "firrtl.transforms.DontTouchAnnotation", target = "~AnnoTest|AnnoTest>in"}]} {
// clang-format on
}
void assertAttrEqual(MlirAttribute lhs, MlirAttribute rhs) {
char lhsBuffer[256] = {0}, rhsBuffer[256] = {0};
mlirAttributePrint(lhs, appendBufferCallback, lhsBuffer);
mlirAttributePrint(rhs, appendBufferCallback, rhsBuffer);
assert(strcmp(lhsBuffer, rhsBuffer) == 0);
}
void testAttrGetIntegerFromString(MlirContext ctx) {
// large negative hex
assertAttrEqual(
mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("0xFF0000000000000000 : i72")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 72), 72,
mlirStringRefCreateFromCString("FF0000000000000000"), 16));
// large positive hex
assertAttrEqual(
mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("0xFF0000000000000000 : i73")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 73), 73,
mlirStringRefCreateFromCString("FF0000000000000000"), 16));
// large negative dec
assertAttrEqual(
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(
"-12345678912345678912345 : i75")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 75), 75,
mlirStringRefCreateFromCString("-12345678912345678912345"), 10));
// large positive dec
assertAttrEqual(
mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("12345678912345678912345 : i75")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 75), 75,
mlirStringRefCreateFromCString("12345678912345678912345"), 10));
// small negative hex
assertAttrEqual(
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0xFF : i8")),
firrtlAttrGetIntegerFromString(mlirIntegerTypeGet(ctx, 8), 8,
mlirStringRefCreateFromCString("FF"), 16));
// small positive hex
assertAttrEqual(
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0xFF : i9")),
firrtlAttrGetIntegerFromString(mlirIntegerTypeGet(ctx, 9), 9,
mlirStringRefCreateFromCString("FF"), 16));
// small negative dec
assertAttrEqual(mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("-114514 : i18")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 18), 18,
mlirStringRefCreateFromCString("-114514"), 10));
// small positive dec
assertAttrEqual(mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("114514 : i18")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 18), 18,
mlirStringRefCreateFromCString("114514"), 10));
}
int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleLoadDialect(mlirGetDialectHandle__firrtl__(), ctx);
testExport(ctx);
testValueFoldFlow(ctx);
testImportAnnotations(ctx);
testAttrGetIntegerFromString(ctx);
return 0;
}