[ESI][Runtime] Filling out the type system (#6476)

Implementing missing classes and exposing to Python. Will build to
proper type-based serialization in Python.
This commit is contained in:
John Demme 2023-11-30 15:57:21 -08:00 committed by GitHub
parent a0ea4b4b4f
commit 25b69dc4f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 405 additions and 36 deletions

View File

@ -11,6 +11,10 @@
!sendI0 = !esi.bundle<[!esi.channel<i0> to "send"]>
!recvI0 = !esi.bundle<[!esi.channel<i0> to "recv"]>
!anyFrom = !esi.bundle<[
!esi.channel<!esi.any> from "recv",
!esi.channel<!esi.any> to "send"]>
esi.service.decl @HostComms {
esi.service.to_server @Send : !sendI8
esi.service.to_client @Recv : !recvI8
@ -33,6 +37,22 @@ hw.module @Loopback (in %clk: !seq.clock) {
esi.service.req.to_server %sendi0_bundle -> <@MyService::@Send> (#esi.appid<"mysvc_send">) : !sendI0
}
esi.service.std.func @funcs
!structFunc = !esi.bundle<[
!esi.channel<!hw.struct<a: ui4, b: si8>> to "arg",
!esi.channel<!hw.array<1xsi8>> from "result"]>
hw.module @LoopbackStruct() {
%callBundle = esi.service.req.to_client <@funcs::@call> (#esi.appid<"structFunc">) : !structFunc
%arg = esi.bundle.unpack %resultChan from %callBundle : !structFunc
%argData, %valid = esi.unwrap.vr %arg, %ready : !hw.struct<a: ui4, b: si8>
%resultElem = hw.struct_extract %argData["b"] : !hw.struct<a: ui4, b: si8>
%resultArray = hw.array_create %resultElem : si8
%resultChan, %ready = esi.wrap.vr %resultArray, %valid : !hw.array<1xsi8>
}
esi.mem.ram @MemA i64 x 20
!write = !hw.struct<address: i5, data: i64>
!writeBundle = !esi.bundle<[!esi.channel<!write> to "req", !esi.channel<i0> from "ack"]>
@ -46,8 +66,6 @@ hw.module @MemoryAccess1(in %clk : !seq.clock, in %rst : i1) {
esi.service.req.to_server %writeBundle -> <@MemA::@write> (#esi.appid<"internal_write">) : !writeBundle
}
esi.service.std.func @funcs
!func1Signature = !esi.bundle<[!esi.channel<i16> to "arg", !esi.channel<i16> from "result"]>
hw.module @CallableFunc1() {
%call = esi.service.req.to_client <@funcs::@call> (#esi.appid<"func1">) : !func1Signature
@ -63,4 +81,5 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) {
hw.instance "m2" @Loopback (clk: %clk: !seq.clock) -> () {esi.appid=#esi.appid<"loopback_inst"[1]>}
hw.instance "int_mem" @MemoryAccess1 (clk: %clk: !seq.clock, rst: %rst: i1) -> ()
hw.instance "func1" @CallableFunc1() -> ()
hw.instance "loopback_struct" @LoopbackStruct() -> ()
}

View File

@ -8,7 +8,33 @@ acc = esi.Accelerator(platform, sys.argv[2])
assert acc.sysinfo().esi_version() == 1
m = acc.manifest()
assert m.api_version == 1
print(m.type_table)
def strType(t: esi.Type) -> str:
if isinstance(t, esi.BundleType):
return "bundle<[{}]>".format(", ".join([
f"{name} {direction} {strType(ty)}" for (name, direction,
ty) in t.channels
]))
if isinstance(t, esi.ChannelType):
return f"channel<{strType(t.inner)}>"
if isinstance(t, esi.ArrayType):
return f"array<{strType(t.element)}, {t.size}>"
if isinstance(t, esi.StructType):
return "struct<{}>".format(", ".join(
["{name}: {strType(ty)}" for (name, ty) in t.fields]))
if isinstance(t, esi.BitsType):
return f"bits<{t.width}>"
if isinstance(t, esi.UIntType):
return f"uint<{t.width}>"
if isinstance(t, esi.SIntType):
return f"sint<{t.width}>"
assert False, f"unknown type: {t}"
for esiType in m.type_table:
print(f"{esiType}:")
print(f" {strType(esiType)}")
d = m.build_design(acc)
@ -21,6 +47,7 @@ assert appid.idx == 0
mysvc_send = loopback.ports[esi.AppID("mysvc_recv")].channels["recv"]
mysvc_send.connect()
mysvc_send.write([0])
assert str(mysvc_send.type) == "<!esi.channel<i0>>"
mysvc_send = loopback.ports[esi.AppID("mysvc_send")].channels["send"]
mysvc_send.connect()

View File

@ -47,14 +47,23 @@ constexpr uint32_t ExpectedVersionNumber = 0;
/// but used by higher level APIs which add types.
class ChannelPort {
public:
ChannelPort(const Type &type) : type(type) {}
virtual ~ChannelPort() = default;
virtual void connect() {}
virtual void disconnect() {}
const Type &getType() const { return type; }
private:
const Type &type;
};
/// A ChannelPort which sends data to the accelerator.
class WriteChannelPort : public ChannelPort {
public:
using ChannelPort::ChannelPort;
/// A very basic write API. Will likely change for performance reasons.
virtual void write(const void *data, size_t size) = 0;
};
@ -62,6 +71,8 @@ public:
/// A ChannelPort which reads data from the accelerator.
class ReadChannelPort : public ChannelPort {
public:
using ChannelPort::ChannelPort;
/// Specify a buffer to read into and a maximum size to read. Returns the
/// number of bytes read, or -1 on error. Basic API, will likely change for
/// performance reasons.

View File

@ -61,6 +61,91 @@ protected:
ChannelVector channels;
};
/// Channels are the basic communication primitives. They are unidirectional and
/// carry one values of one type.
class ChannelType : public Type {
public:
ChannelType(const ID &id, const Type &inner) : Type(id), inner(inner) {}
const Type &getInner() const { return inner; }
private:
const Type &inner;
};
/// The "any" type is a special type which can be used to represent any type, as
/// identified by the type id. Said type id is guaranteed to be present in the
/// manifest. Importantly, the "any" type id over the wire may not be a string
/// as it is in software.
class AnyType : public Type {
public:
AnyType(const ID &id) : Type(id) {}
};
/// Bit vectors include signed, unsigned, and signless integers.
class BitVectorType : public Type {
public:
BitVectorType(const ID &id, uint64_t width) : Type(id), width(width) {}
uint64_t getWidth() const { return width; }
private:
uint64_t width;
};
/// Bits are just an array of bits. They are not interpreted as a number but are
/// identified in the manifest as "signless" ints.
class BitsType : public BitVectorType {
public:
using BitVectorType::BitVectorType;
};
/// Integers are bit vectors which may be signed or unsigned and are interpreted
/// as numbers.
class IntegerType : public BitVectorType {
public:
using BitVectorType::BitVectorType;
};
/// Signed integer.
class SIntType : public IntegerType {
public:
using IntegerType::IntegerType;
};
/// Unsigned integer.
class UIntType : public IntegerType {
public:
using IntegerType::IntegerType;
};
/// Structs are an ordered collection of fields, each with a name and a type.
class StructType : public Type {
public:
using FieldVector = std::vector<std::tuple<std::string, const Type &>>;
StructType(const ID &id, const FieldVector &fields)
: Type(id), fields(fields) {}
const FieldVector &getFields() const { return fields; }
private:
FieldVector fields;
};
/// Arrays have a compile time specified (static) size and an element type.
class ArrayType : public Type {
public:
ArrayType(const ID &id, const Type &elementType, uint64_t size)
: Type(id), elementType(elementType), size(size) {}
const Type &getElementType() const { return elementType; }
uint64_t getSize() const { return size; }
private:
const Type &elementType;
uint64_t size;
};
} // namespace esi
#endif // ESI_TYPES_H

View File

@ -39,6 +39,8 @@ class Manifest::Impl {
friend class ::esi::Manifest;
public:
using TypeCache = map<Type::ID, unique_ptr<Type>>;
Impl(const string &jsonManifest);
auto at(const string &key) const { return manifestJson.at(key); }
@ -99,10 +101,8 @@ public:
const Type &parseType(const nlohmann::json &typeJson);
private:
BundleType *parseBundleType(const nlohmann::json &typeJson);
vector<reference_wrapper<const Type>> _typeTable;
map<Type::ID, unique_ptr<Type>> _types;
TypeCache _types;
// The parsed json.
nlohmann::json manifestJson;
@ -190,7 +190,7 @@ static ModuleInfo parseModuleInfo(const nlohmann::json &mod) {
}
//===----------------------------------------------------------------------===//
// ManifestProxy class implementation.
// Manifest::Impl class implementation.
//===----------------------------------------------------------------------===//
Manifest::Impl::Impl(const string &manifestStr) {
@ -381,7 +381,12 @@ Manifest::Impl::getBundlePorts(AppIDPath idPath,
return ret;
}
BundleType *Manifest::Impl::parseBundleType(const nlohmann::json &typeJson) {
namespace {
const Type &parseType(const nlohmann::json &typeJson,
Manifest::Impl::TypeCache &cache);
BundleType *parseBundleType(const nlohmann::json &typeJson,
Manifest::Impl::TypeCache &cache) {
assert(typeJson.at("mnemonic") == "bundle");
vector<tuple<string, BundleType::Direction, const Type &>> channels;
@ -396,34 +401,98 @@ BundleType *Manifest::Impl::parseBundleType(const nlohmann::json &typeJson) {
throw runtime_error("Malformed manifest: unknown direction '" + dirStr +
"'");
channels.emplace_back(chanJson.at("name"), dir,
parseType(chanJson["type"]));
parseType(chanJson["type"], cache));
}
return new BundleType(typeJson.at("circt_name"), channels);
}
ChannelType *parseChannelType(const nlohmann::json &typeJson,
Manifest::Impl::TypeCache &cache) {
assert(typeJson.at("mnemonic") == "channel");
return new ChannelType(typeJson.at("circt_name"),
parseType(typeJson.at("inner"), cache));
}
BitVectorType *parseInt(const nlohmann::json &typeJson,
Manifest::Impl::TypeCache &cache) {
assert(typeJson.at("mnemonic") == "int");
std::string sign = typeJson.at("signedness");
uint64_t width = typeJson.at("hw_bitwidth");
Type::ID id = typeJson.at("circt_name");
if (sign == "signed")
return new SIntType(id, width);
else if (sign == "unsigned")
return new UIntType(id, width);
else if (sign == "signless")
return new BitsType(id, width);
else
throw runtime_error("Malformed manifest: unknown sign '" + sign + "'");
}
StructType *parseStruct(const nlohmann::json &typeJson,
Manifest::Impl::TypeCache &cache) {
assert(typeJson.at("mnemonic") == "struct");
vector<tuple<string, const Type &>> fields;
for (auto &fieldJson : typeJson["fields"])
fields.emplace_back(fieldJson.at("name"),
parseType(fieldJson["type"], cache));
return new StructType(typeJson.at("circt_name"), fields);
}
ArrayType *parseArray(const nlohmann::json &typeJson,
Manifest::Impl::TypeCache &cache) {
assert(typeJson.at("mnemonic") == "array");
uint64_t size = typeJson.at("size");
return new ArrayType(typeJson.at("circt_name"),
parseType(typeJson.at("element"), cache), size);
}
using TypeParser =
std::function<Type *(const nlohmann::json &, Manifest::Impl::TypeCache &)>;
const std::map<std::string_view, TypeParser> typeParsers = {
{"bundle", parseBundleType},
{"channel", parseChannelType},
{"any",
[](const nlohmann::json &typeJson, Manifest::Impl::TypeCache &cache) {
return new AnyType(typeJson.at("circt_name"));
}},
{"int", parseInt},
{"struct", parseStruct},
{"array", parseArray},
};
// Parse a type if it doesn't already exist in the cache.
const Type &Manifest::Impl::parseType(const nlohmann::json &typeJson) {
const Type &parseType(const nlohmann::json &typeJson,
Manifest::Impl::TypeCache &cache) {
// We use the circt type string as a unique ID.
string circt_name = typeJson.at("circt_name");
// Check the cache.
auto typeF = _types.find(circt_name);
if (typeF != _types.end())
auto typeF = cache.find(circt_name);
if (typeF != cache.end())
return *typeF->second;
// Parse the type.
string mnemonic = typeJson.at("mnemonic");
Type *t;
if (mnemonic == "bundle")
t = parseBundleType(typeJson);
auto f = typeParsers.find(mnemonic);
if (f != typeParsers.end())
t = f->second(typeJson, cache);
else
// Types we don't know about are opaque.
t = new Type(circt_name);
// Insert into the cache.
_types.emplace(circt_name, unique_ptr<Type>(t));
cache.emplace(circt_name, unique_ptr<Type>(t));
return *t;
}
} // namespace
const Type &Manifest::Impl::parseType(const nlohmann::json &typeJson) {
return ::parseType(typeJson, _types);
}
void Manifest::Impl::populateTypes(const nlohmann::json &typesJson) {
for (auto &typeJson : typesJson)

View File

@ -222,8 +222,10 @@ void CosimChannelPort::disconnect() {
namespace {
class WriteCosimChannelPort : public WriteChannelPort {
public:
WriteCosimChannelPort(CosimAccelerator::Impl &impl, string name)
: cosim(make_unique<CosimChannelPort>(impl, name)) {}
WriteCosimChannelPort(CosimAccelerator::Impl &impl, const Type &type,
string name)
: WriteChannelPort(type),
cosim(make_unique<CosimChannelPort>(impl, name)) {}
virtual ~WriteCosimChannelPort() = default;
@ -243,8 +245,9 @@ void WriteCosimChannelPort::write(const void *data, size_t size) {
namespace {
class ReadCosimChannelPort : public ReadChannelPort {
public:
ReadCosimChannelPort(CosimAccelerator::Impl &impl, string name)
: cosim(new CosimChannelPort(impl, name)) {}
ReadCosimChannelPort(CosimAccelerator::Impl &impl, const Type &type,
string name)
: ReadChannelPort(type), cosim(new CosimChannelPort(impl, name)) {}
virtual ~ReadCosimChannelPort() = default;
@ -310,9 +313,9 @@ public:
ChannelPort *port;
if (BundlePort::isWrite(dir, svcDir))
port = new WriteCosimChannelPort(impl, channelName);
port = new WriteCosimChannelPort(impl, type, channelName);
else
port = new ReadCosimChannelPort(impl, channelName);
port = new ReadCosimChannelPort(impl, type, channelName);
impl.channels.emplace(port);
channels.emplace(name, *port);
}

View File

@ -155,9 +155,9 @@ private:
namespace {
class WriteTraceChannelPort : public WriteChannelPort {
public:
WriteTraceChannelPort(TraceAccelerator::Impl &impl, const AppIDPath &id,
const string &portName)
: impl(impl), id(id), portName(portName) {}
WriteTraceChannelPort(TraceAccelerator::Impl &impl, const Type &type,
const AppIDPath &id, const string &portName)
: WriteChannelPort(type), impl(impl), id(id), portName(portName) {}
virtual void write(const void *data, size_t size) override {
impl.write(id, portName, data, size);
@ -173,7 +173,8 @@ protected:
namespace {
class ReadTraceChannelPort : public ReadChannelPort {
public:
ReadTraceChannelPort(TraceAccelerator::Impl &impl) {}
ReadTraceChannelPort(TraceAccelerator::Impl &impl, const Type &type)
: ReadChannelPort(type) {}
virtual ssize_t read(void *data, size_t maxSize) override;
};
@ -201,9 +202,9 @@ public:
for (auto [name, dir, type] : bundleType.getChannels()) {
ChannelPort *port;
if (BundlePort::isWrite(dir, svcDir))
port = new WriteTraceChannelPort(impl, idPath, name);
port = new WriteTraceChannelPort(impl, type, idPath, name);
else
port = new ReadTraceChannelPort(impl);
port = new ReadTraceChannelPort(impl, type);
channels.emplace(name, *port);
impl.adoptChannelPort(port);
}

View File

@ -4,4 +4,10 @@
from .accelerator import Accelerator
from .esiCppAccel import AppID
from .esiCppAccel import (AppID, Type, BundleType, ChannelType, ArrayType,
StructType, BitsType, UIntType, SIntType)
__all__ = [
"Accelerator", "AppID", "Type", "BundleType", "ChannelType", "ArrayType",
"StructType", "BitsType", "UIntType", "SIntType"
]

View File

@ -45,6 +45,34 @@ struct polymorphic_type_hook<ChannelPort> {
// NOLINTNEXTLINE(readability-identifier-naming)
PYBIND11_MODULE(esiCppAccel, m) {
py::class_<Type>(m, "Type")
.def_property_readonly("id", &Type::getID)
.def("__repr__", [](Type &t) { return "<" + t.getID() + ">"; });
py::class_<ChannelType, Type>(m, "ChannelType")
.def_property_readonly("inner", &ChannelType::getInner,
py::return_value_policy::reference_internal);
py::enum_<BundleType::Direction>(m, "Direction")
.value("To", BundleType::Direction::To)
.value("From", BundleType::Direction::From)
.export_values();
py::class_<BundleType, Type>(m, "BundleType")
.def_property_readonly("channels", &BundleType::getChannels,
py::return_value_policy::reference_internal);
py::class_<AnyType, Type>(m, "AnyType");
py::class_<BitVectorType, Type>(m, "BitVectorType")
.def_property_readonly("width", &BitVectorType::getWidth);
py::class_<BitsType, BitVectorType>(m, "BitsType");
py::class_<IntegerType, BitVectorType>(m, "IntegerType");
py::class_<SIntType, IntegerType>(m, "SIntType");
py::class_<UIntType, IntegerType>(m, "UIntType");
py::class_<StructType, Type>(m, "StructType")
.def_property_readonly("fields", &StructType::getFields,
py::return_value_policy::reference_internal);
py::class_<ArrayType, Type>(m, "ArrayType")
.def_property_readonly("element", &ArrayType::getElementType,
py::return_value_policy::reference_internal)
.def_property_readonly("size", &ArrayType::getSize);
py::class_<ModuleInfo>(m, "ModuleInfo")
.def_property_readonly("name", [](ModuleInfo &info) { return info.name; })
.def_property_readonly("summary",
@ -96,7 +124,9 @@ PYBIND11_MODULE(esiCppAccel, m) {
});
py::class_<ChannelPort>(m, "ChannelPort")
.def("connect", &ChannelPort::connect);
.def("connect", &ChannelPort::connect)
.def_property_readonly("type", &ChannelPort::getType,
py::return_value_policy::reference_internal);
py::class_<WriteChannelPort, ChannelPort>(m, "WriteChannelPort")
.def("write", [](WriteChannelPort &p, std::vector<uint8_t> data) {
@ -151,10 +181,6 @@ PYBIND11_MODULE(esiCppAccel, m) {
[](Accelerator &acc) { return acc.getService<services::MMIO>({}); },
py::return_value_policy::reference_internal);
py::class_<Type>(m, "Type")
.def_property_readonly("id", &Type::getID)
.def("__repr__", [](Type &t) { return "<" + t.getID() + ">"; });
py::class_<Manifest>(m, "Manifest")
.def(py::init<std::string>())
.def_property_readonly("api_version", &Manifest::getApiVersion)

View File

@ -7,9 +7,11 @@ from __future__ import annotations
import typing
__all__ = [
'Accelerator', 'AppID', 'BundlePort', 'ChannelPort', 'Design', 'Instance',
'MMIO', 'Manifest', 'ModuleInfo', 'ReadChannelPort', 'SysInfo', 'Type',
'WriteChannelPort'
'Accelerator', 'AnyType', 'AppID', 'ArrayType', 'BitVectorType', 'BitsType',
'BundlePort', 'BundleType', 'ChannelPort', 'ChannelType', 'Design',
'Direction', 'From', 'Instance', 'IntegerType', 'MMIO', 'Manifest',
'ModuleInfo', 'ReadChannelPort', 'SIntType', 'StructType', 'SysInfo', 'To',
'Type', 'UIntType', 'WriteChannelPort'
]
@ -25,6 +27,10 @@ class Accelerator:
...
class AnyType(Type):
pass
class AppID:
def __eq__(self, arg0: AppID) -> bool:
@ -48,6 +54,28 @@ class AppID:
...
class ArrayType(Type):
@property
def element(self) -> Type:
...
@property
def size(self) -> int:
...
class BitVectorType(Type):
@property
def width(self) -> int:
...
class BitsType(BitVectorType):
pass
class BundlePort:
def getRead(self, arg0: str) -> ReadChannelPort:
@ -65,11 +93,29 @@ class BundlePort:
...
class BundleType(Type):
@property
def channels(self) -> list[tuple[str, Direction, Type]]:
...
class ChannelPort:
def connect(self) -> None:
...
@property
def type(self) -> Type:
...
class ChannelType(Type):
@property
def inner(self) -> Type:
...
class Design:
@ -86,6 +132,59 @@ class Design:
...
class Direction:
"""
Members:
To
From
"""
From: typing.ClassVar[Direction] # value = <Direction.From: 1>
To: typing.ClassVar[Direction] # value = <Direction.To: 0>
__members__: typing.ClassVar[dict[
str,
Direction]] # value = {'To': <Direction.To: 0>, 'From': <Direction.From: 1>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: int) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: int) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class Instance(Design):
@property
@ -93,6 +192,10 @@ class Instance(Design):
...
class IntegerType(BitVectorType):
pass
class MMIO:
def read(self, arg0: int) -> int:
@ -151,6 +254,17 @@ class ReadChannelPort(ChannelPort):
...
class SIntType(IntegerType):
pass
class StructType(Type):
@property
def fields(self) -> list[tuple[str, Type]]:
...
class SysInfo:
def esi_version(self) -> int:
@ -170,7 +284,15 @@ class Type:
...
class UIntType(IntegerType):
pass
class WriteChannelPort(ChannelPort):
def write(self, arg0: list[int]) -> None:
...
From: Direction # value = <Direction.From: 1>
To: Direction # value = <Direction.To: 0>