From 47ac46bbe3544ec2df668beed2d0169e09b2aae3 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Mon, 9 Oct 2017 10:37:13 -0700 Subject: [PATCH] Don't join Set-Cookie headers together --- Sources/NIOHTTP1/HTTPTypes.swift | 41 ++++++++++++++----- .../HTTPHeadersTest+XCTest.swift | 1 + Tests/NIOHTTP1Tests/HTTPHeadersTest.swift | 22 ++++++++++ 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/Sources/NIOHTTP1/HTTPTypes.swift b/Sources/NIOHTTP1/HTTPTypes.swift index 1200d6ac..e8e20a1b 100644 --- a/Sources/NIOHTTP1/HTTPTypes.swift +++ b/Sources/NIOHTTP1/HTTPTypes.swift @@ -149,22 +149,41 @@ public struct HTTPHeaders : Sequence, CustomStringConvertible { func write(buffer: inout ByteBuffer) { for (key, values) in storage { - buffer.write(string: key) - buffer.write(staticString: headerSeparator) - - var writerIndex = buffer.writerIndex - for (_, value) in values { - buffer.write(string: value) - writerIndex = buffer.writerIndex - buffer.write(staticString: ",") + if key != "set-cookie" { + writeListHeaderValues(buffer: &buffer, key: key, values: values) + } else { + writeSequentialHeaderValues(buffer: &buffer, key: key, values: values) } - // Discard last , - buffer.moveWriterIndex(to: writerIndex) - buffer.write(staticString: crlf) } buffer.write(staticString: crlf) } + /// Used for most HTTP headers, which can be represented as a single line joined by commas. + private func writeListHeaderValues(buffer: inout ByteBuffer, key: String, values: [(String, String)]) { + buffer.write(string: key) + buffer.write(staticString: headerSeparator) + + var writerIndex = buffer.writerIndex + for (_, value) in values { + buffer.write(string: value) + writerIndex = buffer.writerIndex + buffer.write(staticString: ",") + } + // Discard last , + buffer.moveWriterIndex(to: writerIndex) + buffer.write(staticString: crlf) + } + + /// Used for HTTP headers that cannot be joined with commas, e.g. set-cookie. + private func writeSequentialHeaderValues(buffer: inout ByteBuffer, key: String, values: [(String, String)]) { + for (_, value) in values { + buffer.write(string: key) + buffer.write(staticString: headerSeparator) + buffer.write(string: value) + buffer.write(staticString: crlf) + } + } + public func makeIterator() -> AnyIterator<(name: String, value: String)> { return AnyIterator(HTTPHeadersIterator(wrapping: storage.makeIterator())) } diff --git a/Tests/NIOHTTP1Tests/HTTPHeadersTest+XCTest.swift b/Tests/NIOHTTP1Tests/HTTPHeadersTest+XCTest.swift index 58d95461..e28e6f16 100644 --- a/Tests/NIOHTTP1Tests/HTTPHeadersTest+XCTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPHeadersTest+XCTest.swift @@ -27,6 +27,7 @@ extension HTTPHeadersTest { static var allTests : [(String, (HTTPHeadersTest) -> () throws -> Void)] { return [ ("testCasePreservedButInsensitiveLookup", testCasePreservedButInsensitiveLookup), + ("testWriteHeadersSeparately", testWriteHeadersSeparately), ] } } diff --git a/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift b/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift index 813275dc..2b2b4173 100644 --- a/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift @@ -14,6 +14,7 @@ import Foundation import XCTest +@testable import NIO @testable import NIOHTTP1 class HTTPHeadersTest : XCTestCase { @@ -49,4 +50,25 @@ class HTTPHeadersTest : XCTestCase { } } } + + func testWriteHeadersSeparately() { + let originalHeaders = [ ("User-Agent", "1"), + ("host", "2"), + ("X-SOMETHING", "3"), + ("X-Something", "4"), + ("SET-COOKIE", "foo=bar"), + ("Set-Cookie", "buz=cux")] + + let headers = HTTPHeaders(originalHeaders) + let channel = EmbeddedChannel() + var buffer = channel.allocator.buffer(capacity: 1024) + headers.write(buffer: &buffer) + + let writtenBytes = buffer.string(at: buffer.readerIndex, length: buffer.readableBytes)! + XCTAssertTrue(writtenBytes.contains("user-agent: 1\r\n")) + XCTAssertTrue(writtenBytes.contains("host: 2\r\n")) + XCTAssertTrue(writtenBytes.contains("x-something: 3,4\r\n")) + XCTAssertTrue(writtenBytes.contains("set-cookie: foo=bar\r\n")) + XCTAssertTrue(writtenBytes.contains("set-cookie: buz=cux\r\n")) + } }