[ORC] Add TaskDispatch API and thread it through ExecutorProcessControl.

ExecutorProcessControl objects will now have a TaskDispatcher member which
should be used to dispatch work (in particular, handling incoming packets in
the implementation of remote EPC implementations like SimpleRemoteEPC).

The GenericNamedTask template can be used to wrap function objects that are
callable as 'void()' (along with an optional name to describe the task).
The makeGenericNamedTask functions can be used to create GenericNamedTask
instances without having to name the function object type.

In a future patch ExecutionSession will be updated to use the
ExecutorProcessControl's dispatcher, instead of its DispatchTaskFunction.
This commit is contained in:
Lang Hames 2021-10-08 17:12:06 -07:00
parent 77bc3ba365
commit f341161689
12 changed files with 258 additions and 34 deletions

View File

@ -21,6 +21,7 @@
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ExtensibleRTTI.h"
@ -1254,21 +1255,6 @@ public:
const DenseMap<JITDylib *, SymbolLookupSet> &InitSyms);
};
/// Represents an abstract task for ORC to run.
class Task : public RTTIExtends<Task, RTTIRoot> {
public:
static char ID;
/// Description of the task to be performed. Used for logging.
virtual void printDescription(raw_ostream &OS) = 0;
/// Run the task.
virtual void run() = 0;
private:
void anchor() override;
};
/// A materialization task.
class MaterializationTask : public RTTIExtends<MaterializationTask, Task> {
public:

View File

@ -20,6 +20,7 @@
#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h"
#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/MSVCErrorWorkarounds.h"
@ -121,6 +122,10 @@ public:
ExecutorAddr JITDispatchContext;
};
ExecutorProcessControl(std::shared_ptr<SymbolStringPool> SSP,
std::unique_ptr<TaskDispatcher> D)
: SSP(std::move(SSP)), D(std::move(D)) {}
virtual ~ExecutorProcessControl();
/// Return the ExecutionSession associated with this instance.
@ -136,6 +141,8 @@ public:
/// Return a shared pointer to the SymbolStringPool for this instance.
std::shared_ptr<SymbolStringPool> getSymbolStringPool() const { return SSP; }
TaskDispatcher &getDispatcher() { return *D; }
/// Return the Triple for the target process.
const Triple &getTargetTriple() const { return TargetTriple; }
@ -264,10 +271,9 @@ public:
virtual Error disconnect() = 0;
protected:
ExecutorProcessControl(std::shared_ptr<SymbolStringPool> SSP)
: SSP(std::move(SSP)) {}
std::shared_ptr<SymbolStringPool> SSP;
std::unique_ptr<TaskDispatcher> D;
ExecutionSession *ES = nullptr;
Triple TargetTriple;
unsigned PageSize = 0;
@ -284,9 +290,12 @@ class UnsupportedExecutorProcessControl : public ExecutorProcessControl {
public:
UnsupportedExecutorProcessControl(
std::shared_ptr<SymbolStringPool> SSP = nullptr,
std::unique_ptr<TaskDispatcher> D = nullptr,
const std::string &TT = "", unsigned PageSize = 0)
: ExecutorProcessControl(SSP ? std::move(SSP)
: std::make_shared<SymbolStringPool>()) {
: std::make_shared<SymbolStringPool>(),
D ? std::move(D)
: std::make_unique<InPlaceTaskDispatcher>()) {
this->TargetTriple = Triple(TT);
this->PageSize = PageSize;
}
@ -320,8 +329,9 @@ class SelfExecutorProcessControl
private ExecutorProcessControl::MemoryAccess {
public:
SelfExecutorProcessControl(
std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr);
std::shared_ptr<SymbolStringPool> SSP, std::unique_ptr<TaskDispatcher> D,
Triple TargetTriple, unsigned PageSize,
std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr);
/// Create a SelfExecutorProcessControl with the given symbol string pool and
/// memory manager.
@ -330,6 +340,7 @@ public:
/// be created and used by default.
static Expected<std::unique_ptr<SelfExecutorProcessControl>>
Create(std::shared_ptr<SymbolStringPool> SSP = nullptr,
std::unique_ptr<TaskDispatcher> D = nullptr,
std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr = nullptr);
Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override;

View File

@ -34,9 +34,11 @@ public:
/// Create a SimpleRemoteEPC using the given transport type and args.
template <typename TransportT, typename... TransportTCtorArgTs>
static Expected<std::unique_ptr<SimpleRemoteEPC>>
Create(TransportTCtorArgTs &&...TransportTCtorArgs) {
Create(std::unique_ptr<TaskDispatcher> D,
TransportTCtorArgTs &&...TransportTCtorArgs) {
std::unique_ptr<SimpleRemoteEPC> SREPC(
new SimpleRemoteEPC(std::make_shared<SymbolStringPool>()));
new SimpleRemoteEPC(std::make_shared<SymbolStringPool>(),
std::move(D)));
auto T = TransportT::Create(
*SREPC, std::forward<TransportTCtorArgTs>(TransportTCtorArgs)...);
if (!T)
@ -79,8 +81,9 @@ protected:
virtual Expected<std::unique_ptr<MemoryAccess>> createMemoryAccess();
private:
SimpleRemoteEPC(std::shared_ptr<SymbolStringPool> SSP)
: ExecutorProcessControl(std::move(SSP)) {}
SimpleRemoteEPC(std::shared_ptr<SymbolStringPool> SSP,
std::unique_ptr<TaskDispatcher> D)
: ExecutorProcessControl(std::move(SSP), std::move(D)) {}
Error sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
ExecutorAddr TagAddr, ArrayRef<char> ArgBytes);

View File

@ -0,0 +1,129 @@
//===--------- TaskDispatch.h - ORC task dispatch utils ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Task and TaskDispatch classes.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H
#define LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include <string>
#if LLVM_ENABLE_THREADS
#include <condition_variable>
#include <mutex>
#include <thread>
#endif
namespace llvm {
namespace orc {
/// Represents an abstract task for ORC to run.
class Task : public RTTIExtends<Task, RTTIRoot> {
public:
static char ID;
virtual ~Task() {}
/// Description of the task to be performed. Used for logging.
virtual void printDescription(raw_ostream &OS) = 0;
/// Run the task.
virtual void run() = 0;
private:
void anchor() override;
};
/// Base class for generic tasks.
class GenericNamedTask : public RTTIExtends<GenericNamedTask, Task> {
public:
static char ID;
static const char *DefaultDescription;
};
/// Generic task implementation.
template <typename FnT> class GenericNamedTaskImpl : public GenericNamedTask {
public:
GenericNamedTaskImpl(FnT &&Fn, std::string DescBuffer)
: Fn(std::forward<FnT>(Fn)), Desc(DescBuffer.c_str()),
DescBuffer(std::move(DescBuffer)) {}
GenericNamedTaskImpl(FnT &&Fn, const char *Desc)
: Fn(std::forward<FnT>(Fn)), Desc(Desc) {
assert(Desc && "Description cannot be null");
}
void printDescription(raw_ostream &OS) override { OS << Desc; }
void run() override { Fn(); }
private:
FnT Fn;
const char *Desc;
std::string DescBuffer;
};
/// Create a generic named task from a std::string description.
template <typename FnT>
std::unique_ptr<GenericNamedTask> makeGenericNamedTask(FnT &&Fn,
std::string Desc) {
return std::make_unique<GenericNamedTaskImpl<FnT>>(std::forward<FnT>(Fn),
std::move(Desc));
}
/// Create a generic named task from a const char * description.
template <typename FnT>
std::unique_ptr<GenericNamedTask>
makeGenericNamedTask(FnT &&Fn, const char *Desc = nullptr) {
if (!Desc)
Desc = GenericNamedTask::DefaultDescription;
return std::make_unique<GenericNamedTaskImpl<FnT>>(std::forward<FnT>(Fn),
Desc);
}
/// Abstract base for classes that dispatch ORC Tasks.
class TaskDispatcher {
public:
virtual ~TaskDispatcher();
/// Run the given task.
virtual void dispatch(std::unique_ptr<Task> T) = 0;
/// Called by ExecutionSession. Waits until all tasks have completed.
virtual void shutdown() = 0;
};
/// Runs all tasks on the current thread.
class InPlaceTaskDispatcher : public TaskDispatcher {
public:
void dispatch(std::unique_ptr<Task> T) override;
void shutdown() override;
};
#if LLVM_ENABLE_THREADS
class DynamicThreadPoolTaskDispatcher : public TaskDispatcher {
public:
void dispatch(std::unique_ptr<Task> T) override;
void shutdown() override;
private:
std::mutex DispatchMutex;
bool Running = true;
size_t Outstanding = 0;
std::condition_variable OutstandingCV;
};
#endif // LLVM_ENABLE_THREADS
} // End namespace orc
} // End namespace llvm
#endif // LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H

View File

@ -32,6 +32,7 @@ add_llvm_component_library(LLVMOrcJIT
Speculation.cpp
SpeculateAnalyses.cpp
ExecutorProcessControl.cpp
TaskDispatch.cpp
ThreadSafeModule.cpp
ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc

View File

@ -29,7 +29,6 @@ char SymbolsNotFound::ID = 0;
char SymbolsCouldNotBeRemoved::ID = 0;
char MissingSymbolDefinitions::ID = 0;
char UnexpectedSymbolDefinitions::ID = 0;
char Task::ID = 0;
char MaterializationTask::ID = 0;
RegisterDependenciesFunction NoDependenciesToRegister =
@ -1799,8 +1798,6 @@ void Platform::lookupInitSymbolsAsync(
}
}
void Task::anchor() {}
void MaterializationTask::printDescription(raw_ostream &OS) {
OS << "Materialization task: " << MU->getName() << " in "
<< MR->getTargetJITDylib().getName();

View File

@ -24,9 +24,10 @@ ExecutorProcessControl::MemoryAccess::~MemoryAccess() {}
ExecutorProcessControl::~ExecutorProcessControl() {}
SelfExecutorProcessControl::SelfExecutorProcessControl(
std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
: ExecutorProcessControl(std::move(SSP)) {
std::shared_ptr<SymbolStringPool> SSP, std::unique_ptr<TaskDispatcher> D,
Triple TargetTriple, unsigned PageSize,
std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
: ExecutorProcessControl(std::move(SSP), std::move(D)) {
OwnedMemMgr = std::move(MemMgr);
if (!OwnedMemMgr)
@ -45,11 +46,20 @@ SelfExecutorProcessControl::SelfExecutorProcessControl(
Expected<std::unique_ptr<SelfExecutorProcessControl>>
SelfExecutorProcessControl::Create(
std::shared_ptr<SymbolStringPool> SSP,
std::unique_ptr<TaskDispatcher> D,
std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr) {
if (!SSP)
SSP = std::make_shared<SymbolStringPool>();
if (!D) {
#if LLVM_ENABLE_THREADS
D = std::make_unique<DynamicThreadPoolTaskDispatcher>();
#else
D = std::make_unique<InPlaceTaskDispatcher>();
#endif
}
auto PageSize = sys::Process::getPageSize();
if (!PageSize)
return PageSize.takeError();
@ -57,7 +67,8 @@ SelfExecutorProcessControl::Create(
Triple TT(sys::getProcessTriple());
return std::make_unique<SelfExecutorProcessControl>(
std::move(SSP), std::move(TT), *PageSize, std::move(MemMgr));
std::move(SSP), std::move(D), std::move(TT), *PageSize,
std::move(MemMgr));
}
Expected<tpctypes::DylibHandle>

View File

@ -0,0 +1,48 @@
//===------------ TaskDispatch.cpp - ORC task dispatch utils --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
namespace llvm {
namespace orc {
char Task::ID = 0;
char GenericNamedTask::ID = 0;
const char *GenericNamedTask::DefaultDescription = "Generic Task";
void Task::anchor() {}
TaskDispatcher::~TaskDispatcher() {}
void InPlaceTaskDispatcher::dispatch(std::unique_ptr<Task> T) { T->run(); }
void InPlaceTaskDispatcher::shutdown() {}
#if LLVM_ENABLE_THREADS
void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
{
std::lock_guard<std::mutex> Lock(DispatchMutex);
++Outstanding;
}
std::thread([this, T = std::move(T)]() mutable {
T->run();
std::lock_guard<std::mutex> Lock(DispatchMutex);
--Outstanding;
OutstandingCV.notify_all();
}).detach();
}
void DynamicThreadPoolTaskDispatcher::shutdown() {
std::unique_lock<std::mutex> Lock(DispatchMutex);
Running = false;
OutstandingCV.wait(Lock, [this]() { return Outstanding == 0; });
}
#endif
} // namespace orc
} // namespace llvm

View File

@ -1150,6 +1150,7 @@ Expected<std::unique_ptr<orc::ExecutorProcessControl>> launchRemote() {
// Return a SimpleRemoteEPC instance connected to our end of the pipes.
return orc::SimpleRemoteEPC::Create<orc::FDSimpleRemoteEPCTransport>(
std::make_unique<llvm::orc::InPlaceTaskDispatcher>(),
PipeFD[1][0], PipeFD[0][1]);
#endif
}

View File

@ -718,6 +718,7 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> launchExecutor() {
close(FromExecutor[WriteEnd]);
return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
std::make_unique<DynamicThreadPoolTaskDispatcher>(),
FromExecutor[ReadEnd], ToExecutor[WriteEnd]);
#endif
}
@ -795,7 +796,8 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> connectToExecutor() {
if (!SockFD)
return SockFD.takeError();
return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(*SockFD, *SockFD);
return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
std::make_unique<DynamicThreadPoolTaskDispatcher>(), *SockFD, *SockFD);
#endif
}
@ -832,8 +834,9 @@ Expected<std::unique_ptr<Session>> Session::Create(Triple TT) {
if (!PageSize)
return PageSize.takeError();
EPC = std::make_unique<SelfExecutorProcessControl>(
std::make_shared<SymbolStringPool>(), std::move(TT), *PageSize,
createMemoryManager());
std::make_shared<SymbolStringPool>(),
std::make_unique<DynamicThreadPoolTaskDispatcher>(),
std::move(TT), *PageSize, createMemoryManager());
}
Error Err = Error::success();

View File

@ -32,6 +32,7 @@ add_llvm_unittest(OrcJITTests
SimpleExecutorMemoryManagerTest.cpp
SimplePackedSerializationTest.cpp
SymbolStringPoolTest.cpp
TaskDispatchTest.cpp
ThreadSafeModuleTest.cpp
WrapperFunctionUtilsTest.cpp
)

View File

@ -0,0 +1,33 @@
//===----------- TaskDispatchTest.cpp - Test TaskDispatch APIs ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
#include "gtest/gtest.h"
#include <future>
using namespace llvm;
using namespace llvm::orc;
TEST(InPlaceTaskDispatchTest, GenericNamedTask) {
auto D = std::make_unique<InPlaceTaskDispatcher>();
bool B = false;
D->dispatch(makeGenericNamedTask([&]() { B = true; }));
EXPECT_TRUE(B);
}
#if LLVM_ENABLE_THREADS
TEST(DynamicThreadPoolDispatchTest, GenericNamedTask) {
auto D = std::make_unique<DynamicThreadPoolTaskDispatcher>();
std::promise<bool> P;
auto F = P.get_future();
D->dispatch(makeGenericNamedTask(
[P = std::move(P)]() mutable { P.set_value(true); }));
EXPECT_TRUE(F.get());
}
#endif