ByteBuffer: fast path for contiguous collections

This commit is contained in:
Johannes Weiss 2017-11-03 02:39:54 -07:00
parent c22e5ead6b
commit 839a5050dd
6 changed files with 136 additions and 32 deletions

View File

@ -129,6 +129,13 @@ extension ByteBuffer {
return written
}
@discardableResult
public mutating func write<S: ContiguousCollection>(bytes: S) -> Int where S.Element == UInt8 {
let written = set(bytes: bytes, at: self.writerIndex)
self.moveWriterIndex(forwardBy: written)
return written
}
public func slice() -> ByteBuffer {
return slice(at: self.readerIndex, length: self.readableBytes)!
}

View File

@ -160,7 +160,24 @@ public struct ByteBuffer {
self._writerIndex = newIndex
}
private mutating func set<S: ContiguousCollection>(bytes: S, at index: Index) -> Capacity where S.Element == UInt8 {
let newEndIndex: Index = index + toIndex(Int(bytes.count))
if !isKnownUniquelyReferenced(&self._storage) {
let extraCapacity = newEndIndex > self._slice.upperBound ? newEndIndex - self._slice.upperBound : 0
self.copyStorageAndRebase(extraCapacity: extraCapacity)
}
self.ensureAvailableCapacity(Capacity(bytes.count), at: index)
let base = self._storage.bytes.advanced(by: Int(self._slice.lowerBound + index)).assumingMemoryBound(to: UInt8.self)
bytes.withUnsafeBytes { srcPtr in
base.assign(from: srcPtr.baseAddress!.assumingMemoryBound(to: S.Element.self), count: srcPtr.count)
}
return toCapacity(Int(bytes.count))
}
private mutating func set<S: Collection>(bytes: S, at index: Index) -> Capacity where S.Element == UInt8 {
assert(!([Array<S.Element>.self, StaticString.self, ContiguousArray<S.Element>.self, UnsafeRawBufferPointer.self, UnsafeBufferPointer<UInt8>.self].contains(where: { (t: Any.Type) -> Bool in t == type(of: bytes) })),
"called the slower set<S: Collection> function even though \(S.self) is a ContiguousCollection")
let newEndIndex: Index = index + toIndex(Int(bytes.count))
if !isKnownUniquelyReferenced(&self._storage) {
let extraCapacity = newEndIndex > self._slice.upperBound ? newEndIndex - self._slice.upperBound : 0
@ -295,6 +312,50 @@ extension ByteBuffer: CustomStringConvertible {
}
}
public protocol ContiguousCollection: Collection {
func withUnsafeBytes<R>(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R
}
extension StaticString: Collection {
public func _customIndexOfEquatableElement(_ element: UInt8) -> Int?? {
return Int(element)
}
public typealias Element = UInt8
public typealias SubSequence = ArraySlice<UInt8>
public typealias Index = Int
public var startIndex: Index { return 0 }
public var endIndex: Index { return self.utf8CodeUnitCount }
public func index(after i: Index) -> Index { return i + 1 }
public func index(before i: Index) -> Index { return i - 1 }
public subscript(position: Int) -> StaticString.Element {
get {
return self[position]
}
}
}
extension Array: ContiguousCollection {}
extension ContiguousArray: ContiguousCollection {}
extension StaticString: ContiguousCollection {
public func withUnsafeBytes<R>(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R {
return try fn(UnsafeRawBufferPointer(start: self.utf8Start, count: self.utf8CodeUnitCount))
}
}
extension UnsafeRawBufferPointer: ContiguousCollection {
public func withUnsafeBytes<R>(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R {
return try fn(self)
}
}
extension UnsafeBufferPointer: ContiguousCollection {
public func withUnsafeBytes<R>(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R {
return try fn(UnsafeRawBufferPointer(self))
}
}
/* change types to the user visible `Int` */
extension ByteBuffer {
@discardableResult
@ -302,6 +363,11 @@ extension ByteBuffer {
return Int(self.set(bytes: bytes, at: toIndex(index)))
}
@discardableResult
public mutating func set<S: ContiguousCollection>(bytes: S, at index: Int) -> Int where S.Element == UInt8 {
return Int(self.set(bytes: bytes, at: toIndex(index)))
}
public mutating func moveReaderIndex(forwardBy offset: Int) {
let newIndex = self._readerIndex + toIndex(offset)
precondition(newIndex >= 0 && newIndex <= writerIndex, "new readerIndex: \(newIndex), expected: range(0, \(writerIndex))")

View File

@ -14,6 +14,14 @@
import struct Foundation.Data
extension Data: ContiguousCollection {
public func withUnsafeBytes<R>(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R {
return try self.withUnsafeBytes { (ptr: UnsafePointer<UInt8>) -> R in
return try fn(UnsafeRawBufferPointer(start: ptr, count: self.count))
}
}
}
extension ByteBuffer {
// MARK: Data APIs
@ -26,20 +34,6 @@ extension ByteBuffer {
return data
}
@discardableResult
public mutating func write(data: Data) -> Int {
let bytesWritten = self.set(data: data, at: self.writerIndex)
self.moveWriterIndex(forwardBy: bytesWritten)
return bytesWritten
}
@discardableResult
public mutating func set(data: Data, at index: Int) -> Int {
return data.withUnsafeBytes { ptr in
self.set(bytes: UnsafeRawBufferPointer(start: ptr, count: data.count), at: index)
}
}
public func data(at index: Int, length: Int) -> Data? {
guard index >= 0 && length >= 0 && index <= self.capacity - length else {
return nil

View File

@ -252,7 +252,7 @@ class SniHandlerTest: XCTestCase {
let data = Data(base64Encoded: string, options: .ignoreUnknownCharacters)!
let allocator = ByteBufferAllocator()
var buffer = allocator.buffer(capacity: data.count)
buffer.write(data: data)
buffer.write(bytes: data)
return buffer
}

View File

@ -93,6 +93,8 @@ extension ByteBufferTest {
("testReadWithUnsafeReadableBytesVariantsNothingToRead", testReadWithUnsafeReadableBytesVariantsNothingToRead),
("testReadWithUnsafeReadableBytesVariantsSomethingToRead", testReadWithUnsafeReadableBytesVariantsSomethingToRead),
("testSomePotentialIntegerUnderOrOverflows", testSomePotentialIntegerUnderOrOverflows),
("testWriteForContiguousCollections", testWriteForContiguousCollections),
("testWriteForNonContiguousCollections", testWriteForNonContiguousCollections),
]
}
}

View File

@ -318,7 +318,7 @@ class ByteBufferTest: XCTestCase {
var buffer = allocator.buffer(capacity: 32)
let data = Data(bytes: [1, 2, 3])
XCTAssertEqual(3, buffer.set(data: data, at: 0))
XCTAssertEqual(3, buffer.set(bytes: data, at: 0))
XCTAssertEqual(0, buffer.readableBytes)
XCTAssertEqual(data, buffer.data(at: 0, length: 3))
}
@ -327,7 +327,7 @@ class ByteBufferTest: XCTestCase {
var buffer = allocator.buffer(capacity: 32)
let data = Data(bytes: [1, 2, 3])
XCTAssertEqual(3, buffer.write(data: data))
XCTAssertEqual(3, buffer.write(bytes: data))
XCTAssertEqual(3, buffer.readableBytes)
XCTAssertEqual(data, buffer.readData(length: 3))
}
@ -356,7 +356,7 @@ class ByteBufferTest: XCTestCase {
func testDiscardReadBytesCoW() throws {
var buffer = allocator.buffer(capacity: 32)
let bytesWritten = buffer.write(data: "0123456789abcdef0123456789ABCDEF".data(using: .utf8)!)
let bytesWritten = buffer.write(bytes: "0123456789abcdef0123456789ABCDEF".data(using: .utf8)!)
XCTAssertEqual(32, bytesWritten)
func testAssumptionOriginalBuffer(_ buf: inout ByteBuffer) {
@ -484,11 +484,11 @@ class ByteBufferTest: XCTestCase {
func testExpansion() throws {
var buf = allocator.buffer(capacity: 16)
XCTAssertEqual(16, buf.capacity)
buf.write(data: "0123456789abcdef".data(using: .utf8)!)
buf.write(bytes: "0123456789abcdef".data(using: .utf8)!)
XCTAssertEqual(16, buf.capacity)
XCTAssertEqual(16, buf.writerIndex)
XCTAssertEqual(0, buf.readerIndex)
buf.write(data: "X".data(using: .utf8)!)
buf.write(bytes: "X".data(using: .utf8)!)
XCTAssertEqual(32, buf.capacity)
XCTAssertEqual(17, buf.writerIndex)
XCTAssertEqual(0, buf.readerIndex)
@ -502,7 +502,7 @@ class ByteBufferTest: XCTestCase {
func testExpansion2() throws {
var buf = allocator.buffer(capacity: 2)
XCTAssertEqual(2, buf.capacity)
buf.write(data: "0123456789abcdef".data(using: .utf8)!)
buf.write(bytes: "0123456789abcdef".data(using: .utf8)!)
XCTAssertEqual(16, buf.capacity)
XCTAssertEqual(16, buf.writerIndex)
buf.withUnsafeReadableBytes { ptr in
@ -515,7 +515,7 @@ class ByteBufferTest: XCTestCase {
func testNotEnoughBytesToReadForIntegers() throws {
let byteCount = 15
func initBuffer() {
let written = buf.write(data: Data(Array(repeating: 0, count: byteCount)))
let written = buf.write(bytes: Data(Array(repeating: 0, count: byteCount)))
XCTAssertEqual(byteCount, written)
}
@ -545,7 +545,7 @@ class ByteBufferTest: XCTestCase {
func testNotEnoughBytesToReadForData() throws {
let cap = buf.capacity
let expected = Data(Array(repeating: 0, count: cap))
let written = buf.write(data: expected)
let written = buf.write(bytes: expected)
XCTAssertEqual(cap, written)
XCTAssertEqual(cap, buf.capacity)
@ -610,8 +610,8 @@ class ByteBufferTest: XCTestCase {
var otherBuf = buf
otherBuf.set(data: Data(), at: 0)
buf.set(data: Data(), at: 0)
otherBuf.set(bytes: Data(), at: 0)
buf.set(bytes: Data(), at: 0)
XCTAssertEqual(0, buf.capacity)
XCTAssertEqual(0, otherBuf.capacity)
@ -630,10 +630,10 @@ class ByteBufferTest: XCTestCase {
func testReadDataNotEnoughAvailable() throws {
/* write some bytes */
buf.write(data: Data([0, 1, 2, 3]))
buf.write(bytes: Data([0, 1, 2, 3]))
/* make more available in the buffer that should not be readable */
buf.set(data: Data([4, 5, 6, 7]), at: 4)
buf.set(bytes: Data([4, 5, 6, 7]), at: 4)
let actualNil = buf.readData(length: 5)
XCTAssertNil(actualNil)
@ -647,10 +647,10 @@ class ByteBufferTest: XCTestCase {
func testReadSliceNotEnoughAvailable() throws {
/* write some bytes */
buf.write(data: Data([0, 1, 2, 3]))
buf.write(bytes: Data([0, 1, 2, 3]))
/* make more available in the buffer that should not be readable */
buf.set(data: Data([4, 5, 6, 7]), at: 4)
buf.set(bytes: Data([4, 5, 6, 7]), at: 4)
let actualNil = buf.readSlice(length: 5)
XCTAssertNil(actualNil)
@ -665,7 +665,7 @@ class ByteBufferTest: XCTestCase {
func testSetBuffer() throws {
var src = allocator.buffer(capacity: 4)
src.write(data: Data([0, 1, 2, 3]))
src.write(bytes: Data([0, 1, 2, 3]))
buf.set(buffer: src, at: 1)
@ -678,7 +678,7 @@ class ByteBufferTest: XCTestCase {
func testWriteBuffer() throws {
var src = allocator.buffer(capacity: 4)
src.write(data: Data([0, 1, 2, 3]))
src.write(bytes: Data([0, 1, 2, 3]))
buf.write(buffer: &src)
@ -691,7 +691,7 @@ class ByteBufferTest: XCTestCase {
func testMisalignedIntegerRead() throws {
let value = UInt64(7)
buf.write(data: Data([1]))
buf.write(bytes: Data([1]))
buf.write(integer: value)
let actual = buf.readData(length: 1)
XCTAssertEqual(Data([1]), actual)
@ -907,4 +907,39 @@ class ByteBufferTest: XCTestCase {
testIndexAndLengthFunc(buf.slice)
testIndexAndLengthFunc(buf.string)
}
func testWriteForContiguousCollections() throws {
buf.clear()
var written = buf.write(bytes: [1, 2, 3, 4])
XCTAssertEqual(4, written)
written += [5 as UInt8, 6, 7, 8].withUnsafeBytes { ptr in
buf.write(bytes: ptr)
}
XCTAssertEqual(8, written)
written += [9 as UInt8, 10, 11, 12].withUnsafeBufferPointer { ptr in
buf.write(bytes: ptr)
}
XCTAssertEqual(12, written)
written += buf.write(bytes: ContiguousArray<UInt8>([13, 14, 15, 16]))
XCTAssertEqual(16, written)
written += buf.write(bytes: "ABCD" as StaticString)
XCTAssertEqual(20, written)
written += buf.write(bytes: "EFGH".data(using: .utf8)!)
XCTAssertEqual(24, written)
let expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, "A".utf8.first!, "B".utf8.first!, "C".utf8.first!, "D".utf8.first!, "E".utf8.first!, "F".utf8.first!, "G".utf8.first!, "H".utf8.first!]
XCTAssertEqual(expected, buf.readBytes(length: written)!)
}
func testWriteForNonContiguousCollections() throws {
buf.clear()
let written = buf.write(bytes: "ABCD".utf8)
XCTAssertEqual(4, written)
let expected = ["A".utf8.first!, "B".utf8.first!, "C".utf8.first!, "D".utf8.first!]
XCTAssertEqual(expected, buf.readBytes(length: written)!)
}
}