IDPROD-2482 part 1: Utils needed for ML IDDetector (#693)

Adds utilities needed by the IDDetector ML model including:
- CGRect helpers
- MLMultiArray helper
- A non-maximum suppression algorithm copied from https://github.com/hollance/CoreMLHelpers
This commit is contained in:
Mel 2022-02-03 17:19:23 -08:00 committed by GitHub
parent 327dae826b
commit 054faf5f88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 329 additions and 1 deletions

View File

@ -9,6 +9,10 @@ import Foundation
import CoreGraphics
@_spi(STP) public extension CGRect {
/// Represents the bounds of a normalized coordinate system with range from (0,0) to (1,1)
static let normalizedBounds = CGRect(x: 0, y: 0, width: 1, height: 1)
/**
- Returns: A `CGRect` that has its y-coordinates inverted between the
upper-left corner and lower-left corner.
@ -26,4 +30,54 @@ import CoreGraphics
height: height
)
}
/**
 Converts this rectangle's normalized coordinates from being relative to a
 center-cropped square back to being relative to the original, un-cropped
 area.

 For example, if the original area has a portrait aspect ratio, the
 center-cropped square covers only the middle portion:
 ```
 +---------+
 |         |
 |---------|
 |         |
 |         |
 |         |
 |---------|
 |         |
 +---------+
 ```
 A rect expressed relative to that square is re-expressed relative to the
 full un-cropped area:
 ```
                 +---------+
                 |         |
 +---------+     |         |
 | +--+    |     | +--+    |
 | |  |    | --> | |  |    |
 | +--+    |     | +--+    |
 +---------+     |         |
                 |         |
                 +---------+
 ```
 - Parameters:
   - originalSize: The original size of the un-cropped area.
 */
func convertFromNormalizedCenterCropSquare(
    toOriginalSize originalSize: CGSize
) -> CGRect {
    // The centered square's side equals the smaller dimension.
    let squareSide = min(originalSize.width, originalSize.height)
    // Scale factors mapping square-relative units to original-relative units.
    let horizontalScale = squareSide / originalSize.width
    let verticalScale = squareSide / originalSize.height
    // Both coordinate systems share the same center (0.5, 0.5), so
    // re-center, scale, then shift back.
    return CGRect(
        x: 0.5 + (minX - 0.5) * horizontalScale,
        y: 0.5 + (minY - 0.5) * verticalScale,
        width: width * horizontalScale,
        height: height * verticalScale
    )
}
}

View File

@ -19,4 +19,51 @@ class CGRect_StripeCameraCoreTest: XCTestCase {
XCTAssertEqual(invertedRect.width, 0.5)
XCTAssertEqual(invertedRect.height, 0.5)
}
func testConvertFromNormalizedCenterCropSquareLandscape() {
    // For a 1600x900 frame, the centered crop square is (350,0) 900x900.
    let uncroppedSize = CGSize(width: 1600, height: 900)
    // Square-relative rect (0.25,0.25) 0.25x0.25 corresponds to absolute
    // coordinates (350 + 900*0.25, 900*0.25) with size 900*0.25 x 900*0.25,
    // i.e. (575,225) 225x225.
    let squareRelativeRect = CGRect(x: 0.25, y: 0.25, width: 0.25, height: 0.25)

    let convertedRect = squareRelativeRect.convertFromNormalizedCenterCropSquare(toOriginalSize: uncroppedSize)

    // Normalized against the full frame:
    // (575/1600, 225/900) 225/1600 x 225/900 = (0.359375, 0.25) 0.140625 x 0.25
    XCTAssertEqual(convertedRect.minX, 0.359375)
    XCTAssertEqual(convertedRect.minY, 0.25)
    XCTAssertEqual(convertedRect.width, 0.140625)
    XCTAssertEqual(convertedRect.height, 0.25)
}
func testConvertFromNormalizedCenterCropSquarePortrait() {
    // For a 900x1600 frame, the centered crop square is (0,350) 900x900.
    let uncroppedSize = CGSize(width: 900, height: 1600)
    // Square-relative rect (0.25,0.25) 0.25x0.25 corresponds to absolute
    // coordinates (900*0.25, 350 + 900*0.25) with size 900*0.25 x 900*0.25,
    // i.e. (225,575) 225x225.
    let squareRelativeRect = CGRect(x: 0.25, y: 0.25, width: 0.25, height: 0.25)

    let convertedRect = squareRelativeRect.convertFromNormalizedCenterCropSquare(toOriginalSize: uncroppedSize)

    // Normalized against the full frame:
    // (225/900, 575/1600) 225/900 x 225/1600 = (0.25, 0.359375) 0.25 x 0.140625
    XCTAssertEqual(convertedRect.minX, 0.25)
    XCTAssertEqual(convertedRect.minY, 0.359375)
    XCTAssertEqual(convertedRect.width, 0.25)
    XCTAssertEqual(convertedRect.height, 0.140625)
}
}

View File

@ -63,6 +63,7 @@
E657B5862764307600134033 /* CIImage+StripeIdentityUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = E657B5852764307600134033 /* CIImage+StripeIdentityUnitTest.swift */; };
E657B58827680D2A00134033 /* CIImage_StripeIdentitySnapshotTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = E657B58727680D2A00134033 /* CIImage_StripeIdentitySnapshotTest.swift */; };
E657B58A27681B7F00134033 /* ciimage_stripeidentity_test.png in Resources */ = {isa = PBXBuildFile; fileRef = E657B58927681B7F00134033 /* ciimage_stripeidentity_test.png */; };
E65DD28127A2634900625A17 /* MLMultiArray+StripeIdentity.swift in Sources */ = {isa = PBXBuildFile; fileRef = E65DD28027A2634900625A17 /* MLMultiArray+StripeIdentity.swift */; };
E662FC8B278E2EC2005B0DAD /* ListItemView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E662FC8A278E2EC2005B0DAD /* ListItemView.swift */; };
E662FC8E278E3094005B0DAD /* ListViewSnapshotTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = E662FC8D278E3094005B0DAD /* ListViewSnapshotTest.swift */; };
E662FC90278E4636005B0DAD /* DocumentFileUploadViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = E662FC8F278E4636005B0DAD /* DocumentFileUploadViewController.swift */; };
@ -101,6 +102,7 @@
E6AF1EE9269FDA020091BE99 /* VerificationFlowWebViewSnapshotTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E6AF1EE8269FDA020091BE99 /* VerificationFlowWebViewSnapshotTests.swift */; };
E6B3F63526FBFEA800963EAB /* StripeUICore.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = E6B3F63426FBFEA800963EAB /* StripeUICore.framework */; };
E6B3F64526FD97E600963EAB /* IdentityElementsFactory.swift in Sources */ = {isa = PBXBuildFile; fileRef = E6B3F64426FD97E600963EAB /* IdentityElementsFactory.swift */; };
E6C7BB1D27A8B700000807A6 /* NonMaxSuppression.swift in Sources */ = {isa = PBXBuildFile; fileRef = E6C7BB1C27A8B700000807A6 /* NonMaxSuppression.swift */; };
E6CC14B8269E22060078837F /* StripeCore.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = E6CC14B7269E22060078837F /* StripeCore.framework */; };
E6E146B126950E1E007BDCD8 /* StripeIdentity.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = E6E146A726950E1E007BDCD8 /* StripeIdentity.framework */; };
E6E146B826950E1E007BDCD8 /* StripeIdentity.h in Headers */ = {isa = PBXBuildFile; fileRef = E6E146AA26950E1E007BDCD8 /* StripeIdentity.h */; settings = {ATTRIBUTES = (Public, ); }; };
@ -229,6 +231,7 @@
E657B58727680D2A00134033 /* CIImage_StripeIdentitySnapshotTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CIImage_StripeIdentitySnapshotTest.swift; sourceTree = "<group>"; };
E657B58927681B7F00134033 /* ciimage_stripeidentity_test.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = ciimage_stripeidentity_test.png; sourceTree = "<group>"; };
E657B58F27683D5700134033 /* CIImage_StripeIdentitySnapshotTest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = CIImage_StripeIdentitySnapshotTest.png; sourceTree = "<group>"; };
E65DD28027A2634900625A17 /* MLMultiArray+StripeIdentity.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "MLMultiArray+StripeIdentity.swift"; sourceTree = "<group>"; };
E662FC8A278E2EC2005B0DAD /* ListItemView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ListItemView.swift; sourceTree = "<group>"; };
E662FC8D278E3094005B0DAD /* ListViewSnapshotTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ListViewSnapshotTest.swift; sourceTree = "<group>"; };
E662FC8F278E4636005B0DAD /* DocumentFileUploadViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DocumentFileUploadViewController.swift; sourceTree = "<group>"; };
@ -267,6 +270,7 @@
E6AF1EE8269FDA020091BE99 /* VerificationFlowWebViewSnapshotTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = VerificationFlowWebViewSnapshotTests.swift; sourceTree = "<group>"; };
E6B3F63426FBFEA800963EAB /* StripeUICore.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StripeUICore.framework; sourceTree = BUILT_PRODUCTS_DIR; };
E6B3F64426FD97E600963EAB /* IdentityElementsFactory.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = IdentityElementsFactory.swift; sourceTree = "<group>"; };
E6C7BB1C27A8B700000807A6 /* NonMaxSuppression.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = NonMaxSuppression.swift; sourceTree = "<group>"; };
E6CC14B7269E22060078837F /* StripeCore.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StripeCore.framework; sourceTree = BUILT_PRODUCTS_DIR; };
E6E146A726950E1E007BDCD8 /* StripeIdentity.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = StripeIdentity.framework; sourceTree = BUILT_PRODUCTS_DIR; };
E6E146AA26950E1E007BDCD8 /* StripeIdentity.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = StripeIdentity.h; sourceTree = "<group>"; };
@ -333,6 +337,14 @@
path = "Mock Files";
sourceTree = "<group>";
};
E63B26CE27AB27FD00728F94 /* Helpers */ = {
isa = PBXGroup;
children = (
E6C7BB1C27A8B700000807A6 /* NonMaxSuppression.swift */,
);
path = Helpers;
sourceTree = "<group>";
};
E6465101269E1E0B002EC424 /* Resources */ = {
isa = PBXGroup;
children = (
@ -353,14 +365,22 @@
E646511E269E1F7C002EC424 /* Helpers */ = {
isa = PBXGroup;
children = (
E666A529273A0951001DE130 /* Image.swift */,
E6465121269E202A002EC424 /* STPLocalizedString.swift */,
E606936E2702A2C100742859 /* String+Localized.swift */,
E646511F269E1F8F002EC424 /* StripeIdentityBundleLocator.swift */,
E666A529273A0951001DE130 /* Image.swift */,
);
path = Helpers;
sourceTree = "<group>";
};
E651289427A9BD4800586A1F /* ML */ = {
isa = PBXGroup;
children = (
E63B26CE27AB27FD00728F94 /* Helpers */,
);
path = ML;
sourceTree = "<group>";
};
E6548DCC27276E2300F399B2 /* API Bindings */ = {
isa = PBXGroup;
children = (
@ -443,6 +463,7 @@
E6548F4A2731D6B200F399B2 /* Coordinators */,
E6200DBF2745E8F400A06A8E /* IdentityFlowNavigationController.swift */,
F311D59727A275E800AD8123 /* IdentityUI.swift */,
E651289427A9BD4800586A1F /* ML */,
E6548F42272CAC0500F399B2 /* ViewControllers */,
E6548F40272CAC0500F399B2 /* Views */,
);
@ -522,6 +543,7 @@
E657B58127642FD500134033 /* Categories */ = {
isa = PBXGroup;
children = (
E65DD28027A2634900625A17 /* MLMultiArray+StripeIdentity.swift */,
E657B58227642FEC00134033 /* CIImage+StripeIdentity.swift */,
);
path = Categories;
@ -935,6 +957,7 @@
E6548F6E2731E8C500F399B2 /* VerificationPageDataIDDocument.swift in Sources */,
E6548E5F27279DFF00F399B2 /* VerificationPageRequirements.swift in Sources */,
E6465120269E1F8F002EC424 /* StripeIdentityBundleLocator.swift in Sources */,
E65DD28127A2634900625A17 /* MLMultiArray+StripeIdentity.swift in Sources */,
E666A5252739EFEE001DE130 /* DocumentCaptureViewController.swift in Sources */,
E6548E5D27279DFF00F399B2 /* VerificationPageStaticContentTextPage.swift in Sources */,
E6548F49272CC50900F399B2 /* BiometricConsentViewController.swift in Sources */,
@ -950,6 +973,7 @@
E6548EEA2729025000F399B2 /* StripeCore+Import.swift in Sources */,
E6AF1ECE269FD7990091BE99 /* IdentityVerificationSheet.swift in Sources */,
E6AF1ECA269FD7990091BE99 /* VerificationSheetAnalytics.swift in Sources */,
E6C7BB1D27A8B700000807A6 /* NonMaxSuppression.swift in Sources */,
E6548F5F2731E4E500F399B2 /* VerificationPageDataUpdate.swift in Sources */,
E627AEF5275851640048F88D /* DocumentCaptureView.swift in Sources */,
E6548F522731D9B400F399B2 /* VerificationPageDataRequirements.swift in Sources */,

View File

@ -0,0 +1,15 @@
//
// MLMultiArray+StripeIdentity.swift
// StripeIdentity
//
// Created by Mel Ludowise on 1/26/22.
//
import Foundation
import CoreML
extension MLMultiArray {
    /// Convenience subscript that accepts plain `Int` indices for
    /// multi-dimensional element access, boxing them into the `NSNumber`
    /// array that `MLMultiArray`'s native subscript requires.
    subscript(key: [Int]) -> NSNumber {
        let boxedIndices = key.map { NSNumber(value: $0) }
        return self[boxedIndices]
    }
}

View File

@ -0,0 +1,188 @@
/*
Taken from https://github.com/hollance/CoreMLHelpers/blob/master/CoreMLHelpers/NonMaxSuppression.swift
Copyright (c) 2017-2019 M.I. Hollemans
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
*/
import Foundation
import Accelerate
/// A detected object: a rectangular region paired with the class the model
/// predicted for it and the model's confidence in that prediction.
protocol MLBoundingBox {
    /** Index of the predicted class. */
    var classIndex: Int { get }
    /** Confidence score. */
    var score: Float { get }
    /** Normalized coordinates between 0 and 1. */
    var rect: CGRect { get }
}
/**
Computes intersection-over-union overlap between two bounding boxes.
*/
/**
 Computes the intersection-over-union (Jaccard) overlap between two
 bounding boxes.

 - Returns: A value in `[0, 1]`; `0` when either rect has a non-positive
   area or the rects are disjoint, `1` when they coincide exactly.
 */
func IOU(_ a: CGRect, _ b: CGRect) -> Float {
    let areaA = a.width * a.height
    guard areaA > 0 else { return 0 }
    let areaB = b.width * b.height
    guard areaB > 0 else { return 0 }
    // Overlap region dimensions, clamped to zero for disjoint rects.
    let overlapWidth = max(min(a.maxX, b.maxX) - max(a.minX, b.minX), 0)
    let overlapHeight = max(min(a.maxY, b.maxY) - max(a.minY, b.minY), 0)
    let overlapArea = overlapWidth * overlapHeight
    return Float(overlapArea / (areaA + areaB - overlapArea))
}
/**
Removes bounding boxes that overlap too much with other boxes that have
a higher score.
*/
/**
 Removes bounding boxes that overlap too much with other boxes that have
 a higher score.

 Convenience overload that considers every box in `boundingBoxes`.
 */
func nonMaxSuppression(boundingBoxes: [MLBoundingBox],
                       iouThreshold: Float,
                       maxBoxes: Int) -> [Int] {
    let allIndices = Array(boundingBoxes.indices)
    return nonMaxSuppression(boundingBoxes: boundingBoxes,
                             indices: allIndices,
                             iouThreshold: iouThreshold,
                             maxBoxes: maxBoxes)
}
/**
Removes bounding boxes that overlap too much with other boxes that have
a higher score.
Based on code from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/non_max_suppression_op.cc
- Note: This version of NMS ignores the class of the bounding boxes. Since it
selects the bounding boxes in a greedy fashion, if a certain class has many
boxes that are selected, then it is possible none of the boxes of the other
classes get selected.
- Parameters:
- boundingBoxes: an array of bounding boxes and their scores
- indices: which predictions to look at
- iouThreshold: used to decide whether boxes overlap too much
- maxBoxes: the maximum number of boxes that will be selected
- Returns: the array indices of the selected bounding boxes
*/
/**
 Removes bounding boxes that overlap too much with other boxes that have
 a higher score.

 Based on code from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/non_max_suppression_op.cc

 - Note: This version of NMS ignores the class of the bounding boxes. Since it
   selects the bounding boxes in a greedy fashion, if a certain class has many
   boxes that are selected, then it is possible none of the boxes of the other
   classes get selected.

 - Parameters:
   - boundingBoxes: an array of bounding boxes and their scores
   - indices: which predictions to look at
   - iouThreshold: used to decide whether boxes overlap too much
   - maxBoxes: the maximum number of boxes that will be selected

 - Returns: the array indices of the selected bounding boxes
 */
func nonMaxSuppression(boundingBoxes: [MLBoundingBox],
                       indices: [Int],
                       iouThreshold: Float,
                       maxBoxes: Int) -> [Int] {
    // Greedily consider candidates from highest to lowest confidence.
    let candidates = indices.sorted {
        boundingBoxes[$0].score > boundingBoxes[$1].score
    }
    var kept: [Int] = []
    for candidate in candidates {
        guard kept.count < maxBoxes else { break }
        let candidateRect = boundingBoxes[candidate].rect
        // A candidate that overlaps any already-kept (higher-scoring) box
        // beyond the threshold is too similar and gets dropped.
        let overlapsKeptBox = kept.contains { keptIndex in
            IOU(candidateRect, boundingBoxes[keptIndex].rect) > iouThreshold
        }
        if !overlapsKeptBox {
            kept.append(candidate)
        }
    }
    return kept
}
/**
Multi-class version of non maximum suppression.
Where `nonMaxSuppression()` does not look at the class of the predictions at
all, the multi-class version first selects the best bounding boxes for each
class, and then keeps the best ones of those.
With this method you can usually expect to see at least one bounding box for
each class (unless all the scores for a given class are really low).
Based on code from: https://github.com/tensorflow/models/blob/master/object_detection/core/post_processing.py
- Parameters:
- numClasses: the number of classes
- boundingBoxes: an array of bounding boxes and their scores
- scoreThreshold: used to only keep bounding boxes with a high enough score
- iouThreshold: used to decide whether boxes overlap too much
- maxPerClass: the maximum number of boxes that will be selected per class
- maxTotal: maximum number of boxes that will be selected over all classes
- Returns: the array indices of the selected bounding boxes
*/
/**
 Multi-class version of non maximum suppression.

 Where `nonMaxSuppression()` does not look at the class of the predictions at
 all, the multi-class version first selects the best bounding boxes for each
 class, and then keeps the best ones of those.

 With this method you can usually expect to see at least one bounding box for
 each class (unless all the scores for a given class are really low).

 Based on code from: https://github.com/tensorflow/models/blob/master/object_detection/core/post_processing.py

 - Parameters:
   - numClasses: the number of classes
   - boundingBoxes: an array of bounding boxes and their scores
   - scoreThreshold: used to only keep bounding boxes with a high enough score
   - iouThreshold: used to decide whether boxes overlap too much
   - maxPerClass: the maximum number of boxes that will be selected per class
   - maxTotal: maximum number of boxes that will be selected over all classes

 - Returns: the array indices of the selected bounding boxes
 */
func nonMaxSuppressionMultiClass(numClasses: Int,
                                 boundingBoxes: [MLBoundingBox],
                                 scoreThreshold: Float,
                                 iouThreshold: Float,
                                 maxPerClass: Int,
                                 maxTotal: Int) -> [Int] {
    var survivors: [Int] = []
    for classIndex in 0..<numClasses {
        // Candidate boxes of this class whose score clears the threshold.
        let classCandidates = boundingBoxes.indices.filter { index in
            let box = boundingBoxes[index]
            return box.classIndex == classIndex && box.score > scoreThreshold
        }
        // Per-class NMS keeps at most `maxPerClass` of those candidates.
        survivors += nonMaxSuppression(boundingBoxes: boundingBoxes,
                                       indices: classCandidates,
                                       iouThreshold: iouThreshold,
                                       maxBoxes: maxPerClass)
    }
    // Globally rank the surviving boxes by score and cap at `maxTotal`.
    let ranked = survivors.sorted {
        boundingBoxes[$0].score > boundingBoxes[$1].score
    }
    return Array(ranked.prefix(maxTotal))
}