[codeview] Fix a nasty use after free.

StreamRef was designed to be a thin wrapper over an abstract
stream interface that could itself be treated the same as any
other stream interface.  For this reason, it inherited publicly
from StreamInterface, and stored a StreamInterface* internally.

But StreamRef was also designed to be lightweight and easily
copyable, similar to ArrayRef.  This led to two misuses of
the classes.

1) When creating a StreamRef A from another StreamRef B, it was
   possible to end up with A storing a pointer to B, even when
   B was a temporary object, leading to use after free.
2) The above situation could be repeated ad nauseum, so that
   A stores a pointer to B, which itself stores a pointer to
   another StreamRef C, and so on and so on, creating an
   unnecessarily level of nesting depth.

This patch removes the public inheritance relationship between
StreamRef and StreamInterface, making it so that we can never
accidentally convert a StreamRef to a StreamInterface.

llvm-svn: 271570
This commit is contained in:
Zachary Turner 2016-06-02 19:51:48 +00:00
parent e37d13b9ec
commit f4e9c9ac08
6 changed files with 13 additions and 10 deletions

View File

@ -27,7 +27,7 @@ template <typename Kind> struct CVRecord {
}; };
template <typename Kind> struct VarStreamArrayExtractor<CVRecord<Kind>> { template <typename Kind> struct VarStreamArrayExtractor<CVRecord<Kind>> {
Error operator()(const StreamInterface &Stream, uint32_t &Len, Error operator()(StreamRef Stream, uint32_t &Len,
CVRecord<Kind> &Item) const { CVRecord<Kind> &Item) const {
const RecordPrefix *Prefix = nullptr; const RecordPrefix *Prefix = nullptr;
StreamReader Reader(Stream); StreamReader Reader(Stream);

View File

@ -29,8 +29,7 @@ template <typename T> struct VarStreamArrayExtractor {
// with the following method implemented. On output return `Len` should // with the following method implemented. On output return `Len` should
// contain the number of bytes to consume from the stream, and `Item` should // contain the number of bytes to consume from the stream, and `Item` should
// be initialized with the proper value. // be initialized with the proper value.
Error operator()(const StreamInterface &Stream, uint32_t &Len, Error operator()(StreamRef Stream, uint32_t &Len, T &Item) const = delete;
T &Item) const = delete;
}; };
/// VarStreamArray represents an array of variable length records backed by a /// VarStreamArray represents an array of variable length records backed by a

View File

@ -26,7 +26,7 @@ class StreamRef;
class StreamReader { class StreamReader {
public: public:
StreamReader(const StreamInterface &S); StreamReader(StreamRef Stream);
Error readBytes(ArrayRef<uint8_t> &Buffer, uint32_t Size); Error readBytes(ArrayRef<uint8_t> &Buffer, uint32_t Size);
Error readInteger(uint16_t &Dest); Error readInteger(uint16_t &Dest);
@ -72,7 +72,7 @@ public:
return make_error<CodeViewError>(cv_error_code::corrupt_record); return make_error<CodeViewError>(cv_error_code::corrupt_record);
if (Offset + Length > Stream.getLength()) if (Offset + Length > Stream.getLength())
return make_error<CodeViewError>(cv_error_code::insufficient_buffer); return make_error<CodeViewError>(cv_error_code::insufficient_buffer);
StreamRef View(Stream, Offset, Length); StreamRef View = Stream.slice(Offset, Length);
Array = FixedStreamArray<T>(View); Array = FixedStreamArray<T>(View);
Offset += Length; Offset += Length;
return Error::success(); return Error::success();
@ -84,7 +84,7 @@ public:
uint32_t bytesRemaining() const { return getLength() - getOffset(); } uint32_t bytesRemaining() const { return getLength() - getOffset(); }
private: private:
const StreamInterface &Stream; StreamRef Stream;
uint32_t Offset; uint32_t Offset;
}; };
} // namespace codeview } // namespace codeview

View File

@ -16,7 +16,7 @@
namespace llvm { namespace llvm {
namespace codeview { namespace codeview {
class StreamRef : public StreamInterface { class StreamRef : private StreamInterface {
public: public:
StreamRef() : Stream(nullptr), ViewOffset(0), Length(0) {} StreamRef() : Stream(nullptr), ViewOffset(0), Length(0) {}
StreamRef(const StreamInterface &Stream) StreamRef(const StreamInterface &Stream)
@ -50,6 +50,10 @@ public:
return StreamRef(*Stream, ViewOffset, N); return StreamRef(*Stream, ViewOffset, N);
} }
StreamRef slice(uint32_t Offset, uint32_t Len) const {
return drop_front(Offset).keep_front(Len);
}
bool operator==(const StreamRef &Other) const { bool operator==(const StreamRef &Other) const {
if (Stream != Other.Stream) if (Stream != Other.Stream)
return false; return false;

View File

@ -64,7 +64,7 @@ struct ModuleInfoEx {
namespace codeview { namespace codeview {
template <> struct VarStreamArrayExtractor<pdb::ModInfo> { template <> struct VarStreamArrayExtractor<pdb::ModInfo> {
Error operator()(const StreamInterface &Stream, uint32_t &Length, Error operator()(StreamRef Stream, uint32_t &Length,
pdb::ModInfo &Info) const { pdb::ModInfo &Info) const {
if (auto EC = pdb::ModInfo::initialize(Stream, Info)) if (auto EC = pdb::ModInfo::initialize(Stream, Info))
return EC; return EC;

View File

@ -15,7 +15,7 @@
using namespace llvm; using namespace llvm;
using namespace llvm::codeview; using namespace llvm::codeview;
StreamReader::StreamReader(const StreamInterface &S) : Stream(S), Offset(0) {} StreamReader::StreamReader(StreamRef Stream) : Stream(Stream), Offset(0) {}
Error StreamReader::readBytes(ArrayRef<uint8_t> &Buffer, uint32_t Size) { Error StreamReader::readBytes(ArrayRef<uint8_t> &Buffer, uint32_t Size) {
if (auto EC = Stream.readBytes(Offset, Size, Buffer)) if (auto EC = Stream.readBytes(Offset, Size, Buffer))
@ -80,7 +80,7 @@ Error StreamReader::readStreamRef(StreamRef &Ref) {
Error StreamReader::readStreamRef(StreamRef &Ref, uint32_t Length) { Error StreamReader::readStreamRef(StreamRef &Ref, uint32_t Length) {
if (bytesRemaining() < Length) if (bytesRemaining() < Length)
return make_error<CodeViewError>(cv_error_code::insufficient_buffer); return make_error<CodeViewError>(cv_error_code::insufficient_buffer);
Ref = StreamRef(Stream, Offset, Length); Ref = Stream.slice(Offset, Length);
Offset += Length; Offset += Length;
return Error::success(); return Error::success();
} }