Add internal support for subjectAltName

This commit is contained in:
Cory Benfield 2017-11-01 17:46:56 +00:00
parent cffd11d4e8
commit 4e04a275c0
4 changed files with 170 additions and 0 deletions

View File

@ -34,4 +34,20 @@ static inline void CNIOOpenSSL_SSL_CTX_setAutoECDH(SSL_CTX *ctx) {
#endif
}
static inline int CNIOOpenSSL_sk_GENERAL_NAME_num(STACK_OF(GENERAL_NAME) *x) {
return sk_GENERAL_NAME_num(x);
}
static inline const GENERAL_NAME *CNIOOpenSSL_sk_GENERAL_NAME_value(STACK_OF(GENERAL_NAME) *x, int idx) {
return sk_GENERAL_NAME_value(x, idx);
}
static inline const unsigned char *CNIOOpenSSL_ASN1_STRING_get0_data(ASN1_STRING *x) {
#if (OPENSSL_VERSION_NUMBER < 0x10100000L) || defined(LIBRESSL_VERSION_NUMBER)
return ASN1_STRING_data(x);
#else
return ASN1_STRING_get0_data(x);
#endif
}
#endif

View File

@ -18,10 +18,21 @@
import Glibc
#endif
import CNIOOpenSSL
import NIO
public class OpenSSLCertificate {
internal let ref: UnsafeMutablePointer<X509>
internal enum AlternativeName {
case dnsName(String)
case ipAddress(IPAddress)
}
internal enum IPAddress {
case ipv4(in_addr)
case ipv6(in6_addr)
}
private init(withReference ref: UnsafeMutablePointer<X509>) {
self.ref = ref
}
@ -90,6 +101,15 @@ public class OpenSSLCertificate {
return OpenSSLCertificate(withReference: UnsafeMutablePointer(mutating: pointer))
}
/// Get a sequence of the alternative names in the certificate.
internal func subjectAlternativeNames() -> SubjectAltNameSequence? {
guard let sanExtension = X509_get_ext_d2i(ref, NID_subject_alt_name, nil, nil) else {
return nil
}
let sanNames = sanExtension.assumingMemoryBound(to: stack_st_GENERAL_NAME.self)
return SubjectAltNameSequence(nameStack: sanNames)
}
deinit {
X509_free(ref)
}
@ -100,3 +120,78 @@ extension OpenSSLCertificate: Equatable {
return X509_cmp(lhs.ref, rhs.ref) == 0
}
}
internal class SubjectAltNameSequence: Sequence, IteratorProtocol {
typealias Element = OpenSSLCertificate.AlternativeName
private let nameStack: UnsafeMutablePointer<stack_st_GENERAL_NAME>
private var nextIdx: Int32
private let stackSize: Int32
init(nameStack: UnsafeMutablePointer<stack_st_GENERAL_NAME>) {
self.nameStack = nameStack
self.stackSize = CNIOOpenSSL_sk_GENERAL_NAME_num(nameStack)
self.nextIdx = 0
}
private func addressFromBytes(bytes: UnsafeBufferPointer<UInt8>) -> OpenSSLCertificate.IPAddress? {
switch bytes.count {
case 4:
let addr = bytes.baseAddress?.withMemoryRebound(to: in_addr.self, capacity: 1) {
return $0.pointee
}
guard let innerAddr = addr else {
return nil
}
return .ipv4(innerAddr)
case 16:
let addr = bytes.baseAddress?.withMemoryRebound(to: in6_addr.self, capacity: 1) {
return $0.pointee
}
guard let innerAddr = addr else {
return nil
}
return .ipv6(innerAddr)
default:
return nil
}
}
func next() -> OpenSSLCertificate.AlternativeName? {
guard nextIdx < stackSize else {
return nil
}
guard let name = CNIOOpenSSL_sk_GENERAL_NAME_value(nameStack, nextIdx) else {
fatalError("Unexpected null pointer when unwrapping SAN value")
}
nextIdx += 1
switch name.pointee.type {
case GEN_DNS:
let namePtr = UnsafeBufferPointer(start: CNIOOpenSSL_ASN1_STRING_get0_data(name.pointee.d.ia5),
count: Int(ASN1_STRING_length(name.pointee.d.ia5)))
guard let nameString = String(bytes: namePtr, encoding: .ascii) else {
// This should throw, but we can't throw from next(). Skip this instead.
return next()
}
return .dnsName(nameString)
case GEN_IPADD:
let addrPtr = UnsafeBufferPointer(start: CNIOOpenSSL_ASN1_STRING_get0_data(name.pointee.d.ia5),
count: Int(ASN1_STRING_length(name.pointee.d.ia5)))
guard let addr = addressFromBytes(bytes: addrPtr) else {
// This should throw, but we can't throw from next(). Skip this instead.
return next()
}
return .ipAddress(addr)
default:
// We don't recognise this name type. Skip it.
return next()
}
}
deinit {
GENERAL_NAMES_free(nameStack)
}
}

View File

@ -35,6 +35,8 @@ extension SSLCertificateTest {
("testLoadingGibberishFromMemoryAsDerFails", testLoadingGibberishFromMemoryAsDerFails),
("testLoadingGibberishFromFileAsPemFails", testLoadingGibberishFromFileAsPemFails),
("testLoadingGibberishFromFileAsDerFails", testLoadingGibberishFromFileAsDerFails),
("testEnumeratingSanFields", testEnumeratingSanFields),
("testNonexistentSan", testNonexistentSan),
]
}
}

View File

@ -17,6 +17,28 @@ import XCTest
@testable import NIO
@testable import NIOOpenSSL
let multiSanCert = """
-----BEGIN CERTIFICATE-----
MIIDEzCCAfugAwIBAgIURiMaUmhI1Xr0mZ4p+JmI0XjZTaIwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTE3MTAzMDEyMDUwMFoXDTQwMDEw
MTAwMDAwMFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA26DcKAxqdWivhS/J3Klf+cEnrT2cDzLhmVRCHuQZXiIr
tqr5401KDbRTVOg8v2qIyd8x4+YbpE47JP3fBrcMey70UK/Er8nu28RY3z7gZLLi
Yf+obHdDFCK5JaCGmM61I0c0vp7aMXsyv7h3vjEzTuBMlKR8p37ftaXSUAe3Qk/D
/fzA3k02E2e3ap0Sapd/wUu/0n/MFyy9HkkeykivAzLaaFhhvp3hATdFYC4FLld8
OMB60bC2S13CAljpMlpjU/XLLOUbaPgnNUqE1nFqFBoTl6kV6+ii8Dd5ENVvE7pE
SoNoyGLDUkDRJJMNUHAo0zbxyhd7WOtyZ7B4YBbPswIDAQABo10wWzBLBgNVHREE
RDBCgglsb2NhbGhvc3SCC2V4YW1wbGUuY29tgRB1c2VyQGV4YW1wbGUuY29thwTA
qAABhxAgAQ24AAAAAAAAAAAAAAABMAwGA1UdEwEB/wQCMAAwDQYJKoZIhvcNAQEL
BQADggEBACYBArIoL9ZzVX3M+WmTD5epmGEffrH7diRJZsfpVXi86brBPrbvpTBx
Fa+ZKxBAchPnWn4rxoWVJmTm4WYqZljek7oQKzidu88rMTbsxHA+/qyVPVlQ898I
hgnW4h3FFapKOFqq5Hj2gKKItFIcGoVY2oLTBFkyfAx0ofromGQp3fh58KlPhC0W
GX1nFCea74mGyq60X86aEWiyecYYj5AEcaDrTnGg3HLGTsD3mh8SUZPAda13rO4+
RGtGsA1C9Yovlu9a6pWLgephYJ73XYPmRIGgM64fkUbSuvXNJMYbWnzpoCdW6hka
IEaDUul/WnIkn/JZx8n+wgoWtyQa4EA=
-----END CERTIFICATE-----
"""
private func makeTemporaryFile() -> String {
let template = "/tmp/niotestXXXXXXX"
var templateBytes = Array(template.utf8)
@ -160,4 +182,39 @@ class SSLCertificateTest: XCTestCase {
// Do nothing.
}
}
func testEnumeratingSanFields() throws {
var v4addr = in_addr()
var v6addr = in6_addr()
precondition(inet_pton(AF_INET, "192.168.0.1", &v4addr) == 1)
precondition(inet_pton(AF_INET6, "2001:db8::1", &v6addr) == 1)
let expectedSanFields: [OpenSSLCertificate.AlternativeName] = [
.dnsName("localhost"),
.dnsName("example.com"),
.ipAddress(.ipv4(v4addr)),
.ipAddress(.ipv6(v6addr)),
]
let cert = try OpenSSLCertificate(buffer: [Int8](multiSanCert.utf8CString), format: .pem)
let sans = [OpenSSLCertificate.AlternativeName](cert.subjectAlternativeNames()!)
XCTAssertEqual(sans.count, expectedSanFields.count)
for index in 0..<sans.count {
switch (sans[index], expectedSanFields[index]) {
case (.dnsName(let actualName), .dnsName(let expectedName)):
XCTAssertEqual(actualName, expectedName)
case (.ipAddress(.ipv4(var actualAddr)), .ipAddress(.ipv4(var expectedAddr))):
XCTAssertEqual(memcmp(&actualAddr, &expectedAddr, MemoryLayout<in_addr>.size), 0)
case (.ipAddress(.ipv6(var actualAddr)), .ipAddress(.ipv6(var expectedAddr))):
XCTAssertEqual(memcmp(&actualAddr, &expectedAddr, MemoryLayout<in6_addr>.size), 0)
default:
XCTFail("Invalid entry in sans.")
}
}
}
func testNonexistentSan() throws {
let cert = try OpenSSLCertificate(buffer: [Int8](samplePemCert.utf8CString), format: .pem)
XCTAssertNil(cert.subjectAlternativeNames())
}
}