From 839a5050ddfcaf9ab5e18b840463cb533d1723a9 Mon Sep 17 00:00:00 2001 From: Johannes Weiss Date: Fri, 3 Nov 2017 02:39:54 -0700 Subject: [PATCH] ByteBuffer: fast path for contiguous collections --- Sources/NIO/ByteBuffer-aux.swift | 7 +++ Sources/NIO/ByteBuffer-core.swift | 66 +++++++++++++++++++++ Sources/NIO/ByteBuffer-foundation.swift | 22 +++---- Tests/NIOTLSTests/SNIHandlerTests.swift | 2 +- Tests/NIOTests/ByteBufferTest+XCTest.swift | 2 + Tests/NIOTests/ByteBufferTest.swift | 69 ++++++++++++++++------ 6 files changed, 136 insertions(+), 32 deletions(-) diff --git a/Sources/NIO/ByteBuffer-aux.swift b/Sources/NIO/ByteBuffer-aux.swift index 26b84ab0..3e29d91f 100644 --- a/Sources/NIO/ByteBuffer-aux.swift +++ b/Sources/NIO/ByteBuffer-aux.swift @@ -129,6 +129,13 @@ extension ByteBuffer { return written } + @discardableResult + public mutating func write(bytes: S) -> Int where S.Element == UInt8 { + let written = set(bytes: bytes, at: self.writerIndex) + self.moveWriterIndex(forwardBy: written) + return written + } + public func slice() -> ByteBuffer { return slice(at: self.readerIndex, length: self.readableBytes)! } diff --git a/Sources/NIO/ByteBuffer-core.swift b/Sources/NIO/ByteBuffer-core.swift index 16b8d1dc..724637de 100644 --- a/Sources/NIO/ByteBuffer-core.swift +++ b/Sources/NIO/ByteBuffer-core.swift @@ -160,7 +160,24 @@ public struct ByteBuffer { self._writerIndex = newIndex } + private mutating func set(bytes: S, at index: Index) -> Capacity where S.Element == UInt8 { + let newEndIndex: Index = index + toIndex(Int(bytes.count)) + if !isKnownUniquelyReferenced(&self._storage) { + let extraCapacity = newEndIndex > self._slice.upperBound ? newEndIndex - self._slice.upperBound : 0 + self.copyStorageAndRebase(extraCapacity: extraCapacity) + } + + self.ensureAvailableCapacity(Capacity(bytes.count), at: index) + let base = self._storage.bytes.advanced(by: Int(self._slice.lowerBound + index)).assumingMemoryBound(to: UInt8.self) + bytes.withUnsafeBytes { srcPtr in + base.assign(from: srcPtr.baseAddress!.assumingMemoryBound(to: S.Element.self), count: srcPtr.count) + } + return toCapacity(Int(bytes.count)) + } + private mutating func set(bytes: S, at index: Index) -> Capacity where S.Element == UInt8 { + assert(!([Array.self, StaticString.self, ContiguousArray.self, UnsafeRawBufferPointer.self, UnsafeBufferPointer.self].contains(where: { (t: Any.Type) -> Bool in t == type(of: bytes) })), + "called the slower set function even though \(S.self) is a ContiguousCollection") let newEndIndex: Index = index + toIndex(Int(bytes.count)) if !isKnownUniquelyReferenced(&self._storage) { let extraCapacity = newEndIndex > self._slice.upperBound ? newEndIndex - self._slice.upperBound : 0 @@ -295,6 +312,50 @@ extension ByteBuffer: CustomStringConvertible { } } +public protocol ContiguousCollection: Collection { + func withUnsafeBytes(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R +} + +extension StaticString: Collection { + public func _customIndexOfEquatableElement(_ element: UInt8) -> Int?? { + return Int(element) + } + + public typealias Element = UInt8 + public typealias SubSequence = ArraySlice + + public typealias Index = Int + + public var startIndex: Index { return 0 } + public var endIndex: Index { return self.utf8CodeUnitCount } + public func index(after i: Index) -> Index { return i + 1 } + public func index(before i: Index) -> Index { return i - 1 } + + public subscript(position: Int) -> StaticString.Element { + get { + return self[position] + } + } +} + +extension Array: ContiguousCollection {} +extension ContiguousArray: ContiguousCollection {} +extension StaticString: ContiguousCollection { + public func withUnsafeBytes(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R { + return try fn(UnsafeRawBufferPointer(start: self.utf8Start, count: self.utf8CodeUnitCount)) + } +} +extension UnsafeRawBufferPointer: ContiguousCollection { + public func withUnsafeBytes(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R { + return try fn(self) + } +} +extension UnsafeBufferPointer: ContiguousCollection { + public func withUnsafeBytes(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R { + return try fn(UnsafeRawBufferPointer(self)) + } +} + /* change types to the user visible `Int` */ extension ByteBuffer { @discardableResult @@ -302,6 +363,11 @@ extension ByteBuffer { return Int(self.set(bytes: bytes, at: toIndex(index))) } + @discardableResult + public mutating func set(bytes: S, at index: Int) -> Int where S.Element == UInt8 { + return Int(self.set(bytes: bytes, at: toIndex(index))) + } + public mutating func moveReaderIndex(forwardBy offset: Int) { let newIndex = self._readerIndex + toIndex(offset) precondition(newIndex >= 0 && newIndex <= writerIndex, "new readerIndex: \(newIndex), expected: range(0, \(writerIndex))") diff --git a/Sources/NIO/ByteBuffer-foundation.swift b/Sources/NIO/ByteBuffer-foundation.swift index f7dd00af..ff835053 100644 --- a/Sources/NIO/ByteBuffer-foundation.swift +++ b/Sources/NIO/ByteBuffer-foundation.swift @@ -14,6 +14,14 @@ import struct Foundation.Data +extension Data: ContiguousCollection { + public func withUnsafeBytes(_ fn: (UnsafeRawBufferPointer) throws -> R) rethrows -> R { + return try self.withUnsafeBytes { (ptr: UnsafePointer) -> R in + return try fn(UnsafeRawBufferPointer(start: ptr, count: self.count)) + } + } +} + extension ByteBuffer { // MARK: Data APIs @@ -26,20 +34,6 @@ extension ByteBuffer { return data } - @discardableResult - public mutating func write(data: Data) -> Int { - let bytesWritten = self.set(data: data, at: self.writerIndex) - self.moveWriterIndex(forwardBy: bytesWritten) - return bytesWritten - } - - @discardableResult - public mutating func set(data: Data, at index: Int) -> Int { - return data.withUnsafeBytes { ptr in - self.set(bytes: UnsafeRawBufferPointer(start: ptr, count: data.count), at: index) - } - } - public func data(at index: Int, length: Int) -> Data? { guard index >= 0 && length >= 0 && index <= self.capacity - length else { return nil diff --git a/Tests/NIOTLSTests/SNIHandlerTests.swift b/Tests/NIOTLSTests/SNIHandlerTests.swift index 3d174465..0dd2cdfd 100644 --- a/Tests/NIOTLSTests/SNIHandlerTests.swift +++ b/Tests/NIOTLSTests/SNIHandlerTests.swift @@ -252,7 +252,7 @@ class SniHandlerTest: XCTestCase { let data = Data(base64Encoded: string, options: .ignoreUnknownCharacters)! let allocator = ByteBufferAllocator() var buffer = allocator.buffer(capacity: data.count) - buffer.write(data: data) + buffer.write(bytes: data) return buffer } diff --git a/Tests/NIOTests/ByteBufferTest+XCTest.swift b/Tests/NIOTests/ByteBufferTest+XCTest.swift index 245f24c8..16b87e9c 100644 --- a/Tests/NIOTests/ByteBufferTest+XCTest.swift +++ b/Tests/NIOTests/ByteBufferTest+XCTest.swift @@ -93,6 +93,8 @@ extension ByteBufferTest { ("testReadWithUnsafeReadableBytesVariantsNothingToRead", testReadWithUnsafeReadableBytesVariantsNothingToRead), ("testReadWithUnsafeReadableBytesVariantsSomethingToRead", testReadWithUnsafeReadableBytesVariantsSomethingToRead), ("testSomePotentialIntegerUnderOrOverflows", testSomePotentialIntegerUnderOrOverflows), + ("testWriteForContiguousCollections", testWriteForContiguousCollections), + ("testWriteForNonContiguousCollections", testWriteForNonContiguousCollections), ] } } diff --git a/Tests/NIOTests/ByteBufferTest.swift b/Tests/NIOTests/ByteBufferTest.swift index 887aab38..dbb73adc 100644 --- a/Tests/NIOTests/ByteBufferTest.swift +++ b/Tests/NIOTests/ByteBufferTest.swift @@ -318,7 +318,7 @@ class ByteBufferTest: XCTestCase { var buffer = allocator.buffer(capacity: 32) let data = Data(bytes: [1, 2, 3]) - XCTAssertEqual(3, buffer.set(data: data, at: 0)) + XCTAssertEqual(3, buffer.set(bytes: data, at: 0)) XCTAssertEqual(0, buffer.readableBytes) XCTAssertEqual(data, buffer.data(at: 0, length: 3)) } @@ -327,7 +327,7 @@ class ByteBufferTest: XCTestCase { var buffer = allocator.buffer(capacity: 32) let data = Data(bytes: [1, 2, 3]) - XCTAssertEqual(3, buffer.write(data: data)) + XCTAssertEqual(3, buffer.write(bytes: data)) XCTAssertEqual(3, buffer.readableBytes) XCTAssertEqual(data, buffer.readData(length: 3)) } @@ -356,7 +356,7 @@ class ByteBufferTest: XCTestCase { func testDiscardReadBytesCoW() throws { var buffer = allocator.buffer(capacity: 32) - let bytesWritten = buffer.write(data: "0123456789abcdef0123456789ABCDEF".data(using: .utf8)!) + let bytesWritten = buffer.write(bytes: "0123456789abcdef0123456789ABCDEF".data(using: .utf8)!) XCTAssertEqual(32, bytesWritten) func testAssumptionOriginalBuffer(_ buf: inout ByteBuffer) { @@ -484,11 +484,11 @@ class ByteBufferTest: XCTestCase { func testExpansion() throws { var buf = allocator.buffer(capacity: 16) XCTAssertEqual(16, buf.capacity) - buf.write(data: "0123456789abcdef".data(using: .utf8)!) + buf.write(bytes: "0123456789abcdef".data(using: .utf8)!) XCTAssertEqual(16, buf.capacity) XCTAssertEqual(16, buf.writerIndex) XCTAssertEqual(0, buf.readerIndex) - buf.write(data: "X".data(using: .utf8)!) + buf.write(bytes: "X".data(using: .utf8)!) XCTAssertEqual(32, buf.capacity) XCTAssertEqual(17, buf.writerIndex) XCTAssertEqual(0, buf.readerIndex) @@ -502,7 +502,7 @@ class ByteBufferTest: XCTestCase { func testExpansion2() throws { var buf = allocator.buffer(capacity: 2) XCTAssertEqual(2, buf.capacity) - buf.write(data: "0123456789abcdef".data(using: .utf8)!) + buf.write(bytes: "0123456789abcdef".data(using: .utf8)!) XCTAssertEqual(16, buf.capacity) XCTAssertEqual(16, buf.writerIndex) buf.withUnsafeReadableBytes { ptr in @@ -515,7 +515,7 @@ class ByteBufferTest: XCTestCase { func testNotEnoughBytesToReadForIntegers() throws { let byteCount = 15 func initBuffer() { - let written = buf.write(data: Data(Array(repeating: 0, count: byteCount))) + let written = buf.write(bytes: Data(Array(repeating: 0, count: byteCount))) XCTAssertEqual(byteCount, written) } @@ -545,7 +545,7 @@ class ByteBufferTest: XCTestCase { func testNotEnoughBytesToReadForData() throws { let cap = buf.capacity let expected = Data(Array(repeating: 0, count: cap)) - let written = buf.write(data: expected) + let written = buf.write(bytes: expected) XCTAssertEqual(cap, written) XCTAssertEqual(cap, buf.capacity) @@ -610,8 +610,8 @@ class ByteBufferTest: XCTestCase { var otherBuf = buf - otherBuf.set(data: Data(), at: 0) - buf.set(data: Data(), at: 0) + otherBuf.set(bytes: Data(), at: 0) + buf.set(bytes: Data(), at: 0) XCTAssertEqual(0, buf.capacity) XCTAssertEqual(0, otherBuf.capacity) @@ -630,10 +630,10 @@ class ByteBufferTest: XCTestCase { func testReadDataNotEnoughAvailable() throws { /* write some bytes */ - buf.write(data: Data([0, 1, 2, 3])) + buf.write(bytes: Data([0, 1, 2, 3])) /* make more available in the buffer that should not be readable */ - buf.set(data: Data([4, 5, 6, 7]), at: 4) + buf.set(bytes: Data([4, 5, 6, 7]), at: 4) let actualNil = buf.readData(length: 5) XCTAssertNil(actualNil) @@ -647,10 +647,10 @@ class ByteBufferTest: XCTestCase { func testReadSliceNotEnoughAvailable() throws { /* write some bytes */ - buf.write(data: Data([0, 1, 2, 3])) + buf.write(bytes: Data([0, 1, 2, 3])) /* make more available in the buffer that should not be readable */ - buf.set(data: Data([4, 5, 6, 7]), at: 4) + buf.set(bytes: Data([4, 5, 6, 7]), at: 4) let actualNil = buf.readSlice(length: 5) XCTAssertNil(actualNil) @@ -665,7 +665,7 @@ class ByteBufferTest: XCTestCase { func testSetBuffer() throws { var src = allocator.buffer(capacity: 4) - src.write(data: Data([0, 1, 2, 3])) + src.write(bytes: Data([0, 1, 2, 3])) buf.set(buffer: src, at: 1) @@ -678,7 +678,7 @@ class ByteBufferTest: XCTestCase { func testWriteBuffer() throws { var src = allocator.buffer(capacity: 4) - src.write(data: Data([0, 1, 2, 3])) + src.write(bytes: Data([0, 1, 2, 3])) buf.write(buffer: &src) @@ -691,7 +691,7 @@ class ByteBufferTest: XCTestCase { func testMisalignedIntegerRead() throws { let value = UInt64(7) - buf.write(data: Data([1])) + buf.write(bytes: Data([1])) buf.write(integer: value) let actual = buf.readData(length: 1) XCTAssertEqual(Data([1]), actual) @@ -907,4 +907,39 @@ class ByteBufferTest: XCTestCase { testIndexAndLengthFunc(buf.slice) testIndexAndLengthFunc(buf.string) } + + func testWriteForContiguousCollections() throws { + buf.clear() + var written = buf.write(bytes: [1, 2, 3, 4]) + XCTAssertEqual(4, written) + written += [5 as UInt8, 6, 7, 8].withUnsafeBytes { ptr in + buf.write(bytes: ptr) + } + XCTAssertEqual(8, written) + written += [9 as UInt8, 10, 11, 12].withUnsafeBufferPointer { ptr in + buf.write(bytes: ptr) + } + XCTAssertEqual(12, written) + written += buf.write(bytes: ContiguousArray([13, 14, 15, 16])) + XCTAssertEqual(16, written) + written += buf.write(bytes: "ABCD" as StaticString) + XCTAssertEqual(20, written) + written += buf.write(bytes: "EFGH".data(using: .utf8)!) + XCTAssertEqual(24, written) + + let expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, "A".utf8.first!, "B".utf8.first!, "C".utf8.first!, "D".utf8.first!, "E".utf8.first!, "F".utf8.first!, "G".utf8.first!, "H".utf8.first!] + + XCTAssertEqual(expected, buf.readBytes(length: written)!) + } + + func testWriteForNonContiguousCollections() throws { + buf.clear() + let written = buf.write(bytes: "ABCD".utf8) + XCTAssertEqual(4, written) + + let expected = ["A".utf8.first!, "B".utf8.first!, "C".utf8.first!, "D".utf8.first!] + + XCTAssertEqual(expected, buf.readBytes(length: written)!) + } + }