From ac25e8628c443cddd841c6c91d1c9e23e88969e5 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Thu, 10 Dec 2020 09:35:12 -0800 Subject: [PATCH] [lldb] Deal gracefully with concurrency in the API instrumentation. Prevent lldb from crashing when multiple threads are concurrently accessing the SB API with reproducer capture enabled. The API instrumentation records both the input arguments and the return value, but it cannot block for the duration of the API call. Therefore we introduce a sequence number that allows us to correlate the function with its result and add locking to ensure those two parts are emitted atomically. Using the sequence number, we can detect situations where the return value does not succeed the function call, in which case we print an error saying that concurrency is not (currently) supported. In the future we might attempt to be smarter and read ahead until we've found the return value matching the current call. Differential revision: https://reviews.llvm.org/D92820 --- .../lldb/Utility/ReproducerInstrumentation.h | 41 ++++++++++++++ .../Utility/ReproducerInstrumentation.cpp | 29 +++++++++- .../Utility/ReproducerInstrumentationTest.cpp | 55 ++++++++++++++++++- 3 files changed, 120 insertions(+), 5 deletions(-) diff --git a/lldb/include/lldb/Utility/ReproducerInstrumentation.h b/lldb/include/lldb/Utility/ReproducerInstrumentation.h index e4c31522c4fc..c8a98adf85c7 100644 --- a/lldb/include/lldb/Utility/ReproducerInstrumentation.h +++ b/lldb/include/lldb/Utility/ReproducerInstrumentation.h @@ -333,6 +333,7 @@ public: } template const T &HandleReplayResult(const T &t) { + CheckSequence(Deserialize()); unsigned result = Deserialize(); if (is_trivially_serializable::value) return t; @@ -342,6 +343,7 @@ public: /// Store the returned value in the index-to-object mapping. 
template T &HandleReplayResult(T &t) { + CheckSequence(Deserialize()); unsigned result = Deserialize(); if (is_trivially_serializable::value) return t; @@ -351,6 +353,7 @@ public: /// Store the returned value in the index-to-object mapping. template T *HandleReplayResult(T *t) { + CheckSequence(Deserialize()); unsigned result = Deserialize(); if (is_trivially_serializable::value) return t; @@ -360,6 +363,7 @@ public: /// All returned types are recorded, even when the function returns a void. /// The latter requires special handling. void HandleReplayResultVoid() { + CheckSequence(Deserialize()); unsigned result = Deserialize(); assert(result == 0); (void)result; @@ -369,6 +373,10 @@ public: return m_index_to_object.GetAllObjects(); } + void SetExpectedSequence(unsigned sequence) { + m_expected_sequence = sequence; + } + private: template T Read(ValueTag) { assert(HasData(sizeof(T))); @@ -410,11 +418,17 @@ private: return *(new UnderlyingT(Deserialize())); } + /// Verify that the given sequence number matches what we expect. + void CheckSequence(unsigned sequence); + /// Mapping of indices to objects. IndexToObject m_index_to_object; /// Buffer containing the serialized data. llvm::StringRef m_buffer; + + /// The result's expected sequence number. + llvm::Optional m_expected_sequence; }; /// Partial specialization for C-style strings. 
We read the string value @@ -745,12 +759,15 @@ public: if (!ShouldCapture()) return; + std::lock_guard lock(g_mutex); + unsigned sequence = GetSequenceNumber(); unsigned id = registry.GetID(uintptr_t(f)); #ifdef LLDB_REPRO_INSTR_TRACE Log(id); #endif + serializer.SerializeAll(sequence); serializer.SerializeAll(id); serializer.SerializeAll(args...); @@ -758,6 +775,7 @@ public: typename std::remove_reference::type>::type>::value) { m_result_recorded = false; } else { + serializer.SerializeAll(sequence); serializer.SerializeAll(0); m_result_recorded = true; } @@ -771,16 +789,20 @@ public: if (!ShouldCapture()) return; + std::lock_guard lock(g_mutex); + unsigned sequence = GetSequenceNumber(); unsigned id = registry.GetID(uintptr_t(f)); #ifdef LLDB_REPRO_INSTR_TRACE Log(id); #endif + serializer.SerializeAll(sequence); serializer.SerializeAll(id); serializer.SerializeAll(args...); // Record result. + serializer.SerializeAll(sequence); serializer.SerializeAll(0); m_result_recorded = true; } @@ -806,7 +828,9 @@ public: if (update_boundary) UpdateBoundary(); if (m_serializer && ShouldCapture()) { + std::lock_guard lock(g_mutex); assert(!m_result_recorded); + m_serializer->SerializeAll(GetSequenceNumber()); m_serializer->SerializeAll(r); m_result_recorded = true; } @@ -816,6 +840,7 @@ public: template Result Replay(Deserializer &deserializer, Registry ®istry, uintptr_t addr, bool update_boundary) { + deserializer.SetExpectedSequence(deserializer.Deserialize()); unsigned actual_id = registry.GetID(addr); unsigned id = deserializer.Deserialize(); registry.CheckID(id, actual_id); @@ -826,6 +851,7 @@ public: } void Replay(Deserializer &deserializer, Registry ®istry, uintptr_t addr) { + deserializer.SetExpectedSequence(deserializer.Deserialize()); unsigned actual_id = registry.GetID(addr); unsigned id = deserializer.Deserialize(); registry.CheckID(id, actual_id); @@ -846,6 +872,9 @@ public: static void PrivateThread() { g_global_boundary = true; } private: + static unsigned 
GetNextSequenceNumber() { return g_sequence++; } + unsigned GetSequenceNumber() const; + template friend struct replay; void UpdateBoundary() { if (m_local_boundary) @@ -871,8 +900,17 @@ private: /// Whether the return value was recorded explicitly. bool m_result_recorded; + /// The sequence number for this pair of function and result. + unsigned m_sequence; + /// Whether we're currently across the API boundary. static thread_local bool g_global_boundary; + + /// Global mutex to protect concurrent access. + static std::mutex g_mutex; + + /// Unique, monotonically increasing sequence number. + static std::atomic g_sequence; }; /// To be used as the "Runtime ID" of a constructor. It also invokes the @@ -1014,6 +1052,7 @@ struct invoke_char_ptr { static Result replay(Recorder &recorder, Deserializer &deserializer, Registry ®istry, char *str) { + deserializer.SetExpectedSequence(deserializer.Deserialize()); deserializer.Deserialize(); Class *c = deserializer.Deserialize(); deserializer.Deserialize(); @@ -1035,6 +1074,7 @@ struct invoke_char_ptr { static Result replay(Recorder &recorder, Deserializer &deserializer, Registry ®istry, char *str) { + deserializer.SetExpectedSequence(deserializer.Deserialize()); deserializer.Deserialize(); Class *c = deserializer.Deserialize(); deserializer.Deserialize(); @@ -1055,6 +1095,7 @@ struct invoke_char_ptr { static Result replay(Recorder &recorder, Deserializer &deserializer, Registry ®istry, char *str) { + deserializer.SetExpectedSequence(deserializer.Deserialize()); deserializer.Deserialize(); deserializer.Deserialize(); size_t l = deserializer.Deserialize(); diff --git a/lldb/source/Utility/ReproducerInstrumentation.cpp b/lldb/source/Utility/ReproducerInstrumentation.cpp index 626120c9d71a..b274a10c98fd 100644 --- a/lldb/source/Utility/ReproducerInstrumentation.cpp +++ b/lldb/source/Utility/ReproducerInstrumentation.cpp @@ -8,6 +8,7 @@ #include "lldb/Utility/ReproducerInstrumentation.h" #include "lldb/Utility/Reproducer.h" 
+#include #include #include #include @@ -84,6 +85,16 @@ template <> const char **Deserializer::Deserialize() { return r; } +void Deserializer::CheckSequence(unsigned sequence) { + if (m_expected_sequence && *m_expected_sequence != sequence) + llvm::report_fatal_error( + "The result does not match the preceding " + "function. This is probably the result of concurrent " + "use of the SB API during capture, which is currently not " + "supported."); + m_expected_sequence.reset(); +} + bool Registry::Replay(const FileSpec &file) { auto error_or_file = llvm::MemoryBuffer::getFile(file.GetPath()); if (auto err = error_or_file.getError()) @@ -107,6 +118,7 @@ bool Registry::Replay(Deserializer &deserializer) { setvbuf(stdout, nullptr, _IONBF, 0); while (deserializer.HasData(1)) { + unsigned sequence = deserializer.Deserialize(); unsigned id = deserializer.Deserialize(); #ifndef LLDB_REPRO_INSTR_TRACE @@ -115,6 +127,7 @@ bool Registry::Replay(Deserializer &deserializer) { llvm::errs() << "Replaying " << id << ": " << GetSignature(id) << "\n"; #endif + deserializer.SetExpectedSequence(sequence); GetReplayer(id)->operator()(deserializer); } @@ -181,21 +194,24 @@ unsigned ObjectToIndex::GetIndexForObjectImpl(const void *object) { Recorder::Recorder() : m_serializer(nullptr), m_pretty_func(), m_pretty_args(), - m_local_boundary(false), m_result_recorded(true) { + m_local_boundary(false), m_result_recorded(true), + m_sequence(std::numeric_limits::max()) { if (!g_global_boundary) { g_global_boundary = true; m_local_boundary = true; + m_sequence = GetNextSequenceNumber(); } } Recorder::Recorder(llvm::StringRef pretty_func, std::string &&pretty_args) : m_serializer(nullptr), m_pretty_func(pretty_func), m_pretty_args(pretty_args), m_local_boundary(false), - m_result_recorded(true) { + m_result_recorded(true), + m_sequence(std::numeric_limits::max()) { if (!g_global_boundary) { g_global_boundary = true; m_local_boundary = true; - + m_sequence = GetNextSequenceNumber(); 
LLDB_LOG(GetLogIfAllCategoriesSet(LIBLLDB_LOG_API), "{0} ({1})", m_pretty_func, m_pretty_args); } @@ -206,6 +222,11 @@ Recorder::~Recorder() { UpdateBoundary(); } +unsigned Recorder::GetSequenceNumber() const { + assert(m_sequence != std::numeric_limits::max()); + return m_sequence; +} + void InstrumentationData::Initialize(Serializer &serializer, Registry ®istry) { InstanceImpl().emplace(serializer, registry); @@ -228,3 +249,5 @@ llvm::Optional &InstrumentationData::InstanceImpl() { } thread_local bool lldb_private::repro::Recorder::g_global_boundary = false; +std::atomic lldb_private::repro::Recorder::g_sequence; +std::mutex lldb_private::repro::Recorder::g_mutex; diff --git a/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp b/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp index 1ed00a77249f..e9f6fcf34e17 100644 --- a/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp +++ b/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp @@ -576,8 +576,11 @@ TEST(SerializationRountripTest, SerializeDeserializeObjectPointer) { std::string str; llvm::raw_string_ostream os(str); + unsigned sequence = 123; + Serializer serializer(os); - serializer.SerializeAll(static_cast(1), static_cast(2)); + serializer.SerializeAll(sequence, static_cast(1)); + serializer.SerializeAll(sequence, static_cast(2)); serializer.SerializeAll(&foo, &bar); llvm::StringRef buffer(os.str()); @@ -597,8 +600,11 @@ TEST(SerializationRountripTest, SerializeDeserializeObjectReference) { std::string str; llvm::raw_string_ostream os(str); + unsigned sequence = 123; + Serializer serializer(os); - serializer.SerializeAll(static_cast(1), static_cast(2)); + serializer.SerializeAll(sequence, static_cast(1)); + serializer.SerializeAll(sequence, static_cast(2)); serializer.SerializeAll(foo, bar); llvm::StringRef buffer(os.str()); @@ -1114,3 +1120,48 @@ TEST(PassiveReplayTest, InstrumentedBarPtr) { bar.Validate(); } } + +TEST(RecordReplayTest, ValidSequence) { + std::string str; + 
llvm::raw_string_ostream os(str); + + { + auto data = TestInstrumentationDataRAII::GetRecordingData(os); + + unsigned sequence = 1; + int (*f)() = &lldb_private::repro::invoke::method< + InstrumentedFoo::F>::record; + unsigned id = g_registry->GetID(uintptr_t(f)); + g_serializer->SerializeAll(sequence, id); + + unsigned result = 0; + g_serializer->SerializeAll(sequence, result); + } + + TestingRegistry registry; + Deserializer deserializer(os.str()); + registry.Replay(deserializer); +} + +TEST(RecordReplayTest, InvalidSequence) { + std::string str; + llvm::raw_string_ostream os(str); + + { + auto data = TestInstrumentationDataRAII::GetRecordingData(os); + + unsigned sequence = 1; + int (*f)() = &lldb_private::repro::invoke::method< + InstrumentedFoo::F>::record; + unsigned id = g_registry->GetID(uintptr_t(f)); + g_serializer->SerializeAll(sequence, id); + + unsigned result = 0; + unsigned invalid_sequence = 2; + g_serializer->SerializeAll(invalid_sequence, result); + } + + TestingRegistry registry; + Deserializer deserializer(os.str()); + EXPECT_DEATH(registry.Replay(deserializer), ""); +}