IDPROD-2482 part 1: Utils needed for ML IDDetector (#693)
Adds utilities needed by the IDDetector ML model including: - CGRect helpers - MLMultiArray helper - A non-maximum suppression algorithm copied from https://github.com/hollance/CoreMLHelpers
This commit is contained in:
parent
327dae826b
commit
054faf5f88
|
@ -9,6 +9,10 @@ import Foundation
|
|||
import CoreGraphics
|
||||
|
||||
@_spi(STP) public extension CGRect {
|
||||
|
||||
/// Represents the bounds of a normalized coordinate system with range from (0,0) to (1,1)
|
||||
static let normalizedBounds = CGRect(x: 0, y: 0, width: 1, height: 1)
|
||||
|
||||
/**
|
||||
- Returns: A `CGRect` that has its y-coordinates inverted between the
|
||||
upper-left corner and lower-left corner.
|
||||
|
@ -26,4 +30,54 @@ import CoreGraphics
|
|||
height: height
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
Converts a rectangle that's using a normalized coordinate system from a
|
||||
center-crop coordinate system to an un-cropped coordinate system
|
||||
|
||||
Example, if the original size has a portrait aspect ratio, center-cropping
|
||||
the rect will result in the square area:
|
||||
```
|
||||
+---------+
|
||||
| |
|
||||
|---------|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
|---------|
|
||||
| |
|
||||
+---------+
|
||||
```
|
||||
|
||||
This method converts the rect's coordinate relative to the center-cropped
|
||||
area into coordinates relative to the original un-cropped area:
|
||||
```
|
||||
+---------+
|
||||
| |
|
||||
+---------+ | |
|
||||
| +--+ | | +--+ |
|
||||
| | | | --> | | | |
|
||||
| +--+ | | +--+ |
|
||||
+---------+ | |
|
||||
| |
|
||||
+---------+
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
- size: The original size of the un-cropped area.
|
||||
*/
|
||||
func convertFromNormalizedCenterCropSquare(
|
||||
toOriginalSize originalSize: CGSize
|
||||
) -> CGRect {
|
||||
let croppedWidth = min(originalSize.width, originalSize.height)
|
||||
let scaleX = croppedWidth / originalSize.width
|
||||
let scaleY = croppedWidth / originalSize.height
|
||||
|
||||
return CGRect(
|
||||
x: (minX - 0.5) * scaleX + 0.5,
|
||||
y: (minY - 0.5) * scaleY + 0.5,
|
||||
width: width * scaleX,
|
||||
height: height * scaleY
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,4 +19,51 @@ class CGRect_StripeCameraCoreTest: XCTestCase {
|
|||
XCTAssertEqual(invertedRect.width, 0.5)
|
||||
XCTAssertEqual(invertedRect.height, 0.5)
|
||||
}
|
||||
|
||||
func testConvertFromNormalizedCenterCropSquareLandscape() {
|
||||
// Centered-square rect corresponds to (350,0) 900x900
|
||||
let originalSize = CGSize(width: 1600, height: 900)
|
||||
|
||||
|
||||
// Corresponds to actual coordinates of
|
||||
// (350+900*0.25,900*0.25) 900*0.25 x 900*0.25
|
||||
// = (575,225) 225x225
|
||||
let normalizedSquareRect = CGRect(x: 0.25, y: 0.25, width: 0.25, height: 0.25)
|
||||
|
||||
|
||||
// Should have coordinates of
|
||||
// (575/1600, 225/900) 225/1600 x 225/900
|
||||
// = (0.359375,0.25) 0.140625 x 0.25
|
||||
let convertedRect = normalizedSquareRect.convertFromNormalizedCenterCropSquare(toOriginalSize: originalSize)
|
||||
|
||||
|
||||
XCTAssertEqual(convertedRect.minX, 0.359375)
|
||||
XCTAssertEqual(convertedRect.minY, 0.25)
|
||||
XCTAssertEqual(convertedRect.width, 0.140625)
|
||||
XCTAssertEqual(convertedRect.height, 0.25)
|
||||
}
|
||||
|
||||
func testConvertFromNormalizedCenterCropSquarePortrait() {
|
||||
// Centered-square rect corresponds to (0,350) 900x900
|
||||
let originalSize = CGSize(width: 900, height: 1600)
|
||||
|
||||
|
||||
// Corresponds to actual coordinates of
|
||||
// (900*0.25,350+900*0.25) 900*0.25 x 900*0.25
|
||||
// = (225,575) 225x225
|
||||
let normalizedSquareRect = CGRect(x: 0.25, y: 0.25, width: 0.25, height: 0.25)
|
||||
|
||||
|
||||
// Should have coordinates of
|
||||
// (225/900, 575/1600) 225/900 x 225/1600
|
||||
// = (0.25,0.359375) 0.25 x 0.140625
|
||||
let convertedRect = normalizedSquareRect.convertFromNormalizedCenterCropSquare(toOriginalSize: originalSize)
|
||||
|
||||
|
||||
XCTAssertEqual(convertedRect.minX, 0.25)
|
||||
XCTAssertEqual(convertedRect.minY, 0.359375)
|
||||
XCTAssertEqual(convertedRect.width, 0.25)
|
||||
XCTAssertEqual(convertedRect.height, 0.140625)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -63,6 +63,7 @@
|
|||
E657B5862764307600134033 /* CIImage+StripeIdentityUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = E657B5852764307600134033 /* CIImage+StripeIdentityUnitTest.swift */; };
|
||||
E657B58827680D2A00134033 /* CIImage_StripeIdentitySnapshotTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = E657B58727680D2A00134033 /* CIImage_StripeIdentitySnapshotTest.swift */; };
|
||||
E657B58A27681B7F00134033 /* ciimage_stripeidentity_test.png in Resources */ = {isa = PBXBuildFile; fileRef = E657B58927681B7F00134033 /* ciimage_stripeidentity_test.png */; };
|
||||
E65DD28127A2634900625A17 /* MLMultiArray+StripeIdentity.swift in Sources */ = {isa = PBXBuildFile; fileRef = E65DD28027A2634900625A17 /* MLMultiArray+StripeIdentity.swift */; };
|
||||
E662FC8B278E2EC2005B0DAD /* ListItemView.swift in Sources */ = {isa = PBXBuildFile; fileRef = E662FC8A278E2EC2005B0DAD /* ListItemView.swift */; };
|
||||
E662FC8E278E3094005B0DAD /* ListViewSnapshotTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = E662FC8D278E3094005B0DAD /* ListViewSnapshotTest.swift */; };
|
||||
E662FC90278E4636005B0DAD /* DocumentFileUploadViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = E662FC8F278E4636005B0DAD /* DocumentFileUploadViewController.swift */; };
|
||||
|
@ -101,6 +102,7 @@
|
|||
E6AF1EE9269FDA020091BE99 /* VerificationFlowWebViewSnapshotTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E6AF1EE8269FDA020091BE99 /* VerificationFlowWebViewSnapshotTests.swift */; };
|
||||
E6B3F63526FBFEA800963EAB /* StripeUICore.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = E6B3F63426FBFEA800963EAB /* StripeUICore.framework */; };
|
||||
E6B3F64526FD97E600963EAB /* IdentityElementsFactory.swift in Sources */ = {isa = PBXBuildFile; fileRef = E6B3F64426FD97E600963EAB /* IdentityElementsFactory.swift */; };
|
||||
E6C7BB1D27A8B700000807A6 /* NonMaxSuppression.swift in Sources */ = {isa = PBXBuildFile; fileRef = E6C7BB1C27A8B700000807A6 /* NonMaxSuppression.swift */; };
|
||||
E6CC14B8269E22060078837F /* StripeCore.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = E6CC14B7269E22060078837F /* StripeCore.framework */; };
|
||||
E6E146B126950E1E007BDCD8 /* StripeIdentity.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = E6E146A726950E1E007BDCD8 /* StripeIdentity.framework */; };
|
||||
E6E146B826950E1E007BDCD8 /* StripeIdentity.h in Headers */ = {isa = PBXBuildFile; fileRef = E6E146AA26950E1E007BDCD8 /* StripeIdentity.h */; settings = {ATTRIBUTES = (Public, ); }; };
|
||||
|
@ -229,6 +231,7 @@
|
|||
E657B58727680D2A00134033 /* CIImage_StripeIdentitySnapshotTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CIImage_StripeIdentitySnapshotTest.swift; sourceTree = "<group>"; };
|
||||
E657B58927681B7F00134033 /* ciimage_stripeidentity_test.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = ciimage_stripeidentity_test.png; sourceTree = "<group>"; };
|
||||
E657B58F27683D5700134033 /* CIImage_StripeIdentitySnapshotTest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = CIImage_StripeIdentitySnapshotTest.png; sourceTree = "<group>"; };
|
||||
E65DD28027A2634900625A17 /* MLMultiArray+StripeIdentity.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "MLMultiArray+StripeIdentity.swift"; sourceTree = "<group>"; };
|
||||
E662FC8A278E2EC2005B0DAD /* ListItemView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ListItemView.swift; sourceTree = "<group>"; };
|
||||
E662FC8D278E3094005B0DAD /* ListViewSnapshotTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ListViewSnapshotTest.swift; sourceTree = "<group>"; };
|
||||
E662FC8F278E4636005B0DAD /* DocumentFileUploadViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DocumentFileUploadViewController.swift; sourceTree = "<group>"; };
|
||||
|
@ -267,6 +270,7 @@
|
|||
E6AF1EE8269FDA020091BE99 /* VerificationFlowWebViewSnapshotTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = VerificationFlowWebViewSnapshotTests.swift; sourceTree = "<group>"; };
|
||||
E6B3F63426FBFEA800963EAB /* StripeUICore.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StripeUICore.framework; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
E6B3F64426FD97E600963EAB /* IdentityElementsFactory.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = IdentityElementsFactory.swift; sourceTree = "<group>"; };
|
||||
E6C7BB1C27A8B700000807A6 /* NonMaxSuppression.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = NonMaxSuppression.swift; sourceTree = "<group>"; };
|
||||
E6CC14B7269E22060078837F /* StripeCore.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StripeCore.framework; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
E6E146A726950E1E007BDCD8 /* StripeIdentity.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = StripeIdentity.framework; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
E6E146AA26950E1E007BDCD8 /* StripeIdentity.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = StripeIdentity.h; sourceTree = "<group>"; };
|
||||
|
@ -333,6 +337,14 @@
|
|||
path = "Mock Files";
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E63B26CE27AB27FD00728F94 /* Helpers */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E6C7BB1C27A8B700000807A6 /* NonMaxSuppression.swift */,
|
||||
);
|
||||
path = Helpers;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E6465101269E1E0B002EC424 /* Resources */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
|
@ -353,14 +365,22 @@
|
|||
E646511E269E1F7C002EC424 /* Helpers */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E666A529273A0951001DE130 /* Image.swift */,
|
||||
E6465121269E202A002EC424 /* STPLocalizedString.swift */,
|
||||
E606936E2702A2C100742859 /* String+Localized.swift */,
|
||||
E646511F269E1F8F002EC424 /* StripeIdentityBundleLocator.swift */,
|
||||
E666A529273A0951001DE130 /* Image.swift */,
|
||||
);
|
||||
path = Helpers;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E651289427A9BD4800586A1F /* ML */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E63B26CE27AB27FD00728F94 /* Helpers */,
|
||||
);
|
||||
path = ML;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E6548DCC27276E2300F399B2 /* API Bindings */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
|
@ -443,6 +463,7 @@
|
|||
E6548F4A2731D6B200F399B2 /* Coordinators */,
|
||||
E6200DBF2745E8F400A06A8E /* IdentityFlowNavigationController.swift */,
|
||||
F311D59727A275E800AD8123 /* IdentityUI.swift */,
|
||||
E651289427A9BD4800586A1F /* ML */,
|
||||
E6548F42272CAC0500F399B2 /* ViewControllers */,
|
||||
E6548F40272CAC0500F399B2 /* Views */,
|
||||
);
|
||||
|
@ -522,6 +543,7 @@
|
|||
E657B58127642FD500134033 /* Categories */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E65DD28027A2634900625A17 /* MLMultiArray+StripeIdentity.swift */,
|
||||
E657B58227642FEC00134033 /* CIImage+StripeIdentity.swift */,
|
||||
);
|
||||
path = Categories;
|
||||
|
@ -935,6 +957,7 @@
|
|||
E6548F6E2731E8C500F399B2 /* VerificationPageDataIDDocument.swift in Sources */,
|
||||
E6548E5F27279DFF00F399B2 /* VerificationPageRequirements.swift in Sources */,
|
||||
E6465120269E1F8F002EC424 /* StripeIdentityBundleLocator.swift in Sources */,
|
||||
E65DD28127A2634900625A17 /* MLMultiArray+StripeIdentity.swift in Sources */,
|
||||
E666A5252739EFEE001DE130 /* DocumentCaptureViewController.swift in Sources */,
|
||||
E6548E5D27279DFF00F399B2 /* VerificationPageStaticContentTextPage.swift in Sources */,
|
||||
E6548F49272CC50900F399B2 /* BiometricConsentViewController.swift in Sources */,
|
||||
|
@ -950,6 +973,7 @@
|
|||
E6548EEA2729025000F399B2 /* StripeCore+Import.swift in Sources */,
|
||||
E6AF1ECE269FD7990091BE99 /* IdentityVerificationSheet.swift in Sources */,
|
||||
E6AF1ECA269FD7990091BE99 /* VerificationSheetAnalytics.swift in Sources */,
|
||||
E6C7BB1D27A8B700000807A6 /* NonMaxSuppression.swift in Sources */,
|
||||
E6548F5F2731E4E500F399B2 /* VerificationPageDataUpdate.swift in Sources */,
|
||||
E627AEF5275851640048F88D /* DocumentCaptureView.swift in Sources */,
|
||||
E6548F522731D9B400F399B2 /* VerificationPageDataRequirements.swift in Sources */,
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
//
|
||||
// MLMultiArray+StripeIdentity.swift
|
||||
// StripeIdentity
|
||||
//
|
||||
// Created by Mel Ludowise on 1/26/22.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import CoreML
|
||||
|
||||
extension MLMultiArray {
|
||||
subscript(key: [Int]) -> NSNumber {
|
||||
return self[key.map { NSNumber(value: $0) }]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,188 @@
|
|||
/*
|
||||
|
||||
Taken from https://github.com/hollance/CoreMLHelpers/blob/master/CoreMLHelpers/NonMaxSuppression.swift
|
||||
|
||||
Copyright (c) 2017-2019 M.I. Hollemans
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to
|
||||
deal in the Software without restriction, including without limitation the
|
||||
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
sell copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
import Foundation
|
||||
import Accelerate
|
||||
|
||||
protocol MLBoundingBox {
|
||||
/** Index of the predicted class. */
|
||||
var classIndex: Int { get }
|
||||
|
||||
/** Confidence score. */
|
||||
var score: Float { get }
|
||||
|
||||
/** Normalized coordinates between 0 and 1. */
|
||||
var rect: CGRect { get }
|
||||
}
|
||||
|
||||
/**
|
||||
Computes intersection-over-union overlap between two bounding boxes.
|
||||
*/
|
||||
func IOU(_ a: CGRect, _ b: CGRect) -> Float {
|
||||
let areaA = a.width * a.height
|
||||
if areaA <= 0 { return 0 }
|
||||
|
||||
let areaB = b.width * b.height
|
||||
if areaB <= 0 { return 0 }
|
||||
|
||||
let intersectionMinX = max(a.minX, b.minX)
|
||||
let intersectionMinY = max(a.minY, b.minY)
|
||||
let intersectionMaxX = min(a.maxX, b.maxX)
|
||||
let intersectionMaxY = min(a.maxY, b.maxY)
|
||||
let intersectionArea = max(intersectionMaxY - intersectionMinY, 0) *
|
||||
max(intersectionMaxX - intersectionMinX, 0)
|
||||
return Float(intersectionArea / (areaA + areaB - intersectionArea))
|
||||
}
|
||||
|
||||
/**
|
||||
Removes bounding boxes that overlap too much with other boxes that have
|
||||
a higher score.
|
||||
*/
|
||||
func nonMaxSuppression(boundingBoxes: [MLBoundingBox],
|
||||
iouThreshold: Float,
|
||||
maxBoxes: Int) -> [Int] {
|
||||
return nonMaxSuppression(boundingBoxes: boundingBoxes,
|
||||
indices: Array(boundingBoxes.indices),
|
||||
iouThreshold: iouThreshold,
|
||||
maxBoxes: maxBoxes)
|
||||
}
|
||||
|
||||
/**
|
||||
Removes bounding boxes that overlap too much with other boxes that have
|
||||
a higher score.
|
||||
|
||||
Based on code from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/non_max_suppression_op.cc
|
||||
|
||||
- Note: This version of NMS ignores the class of the bounding boxes. Since it
|
||||
selects the bounding boxes in a greedy fashion, if a certain class has many
|
||||
boxes that are selected, then it is possible none of the boxes of the other
|
||||
classes get selected.
|
||||
|
||||
- Parameters:
|
||||
- boundingBoxes: an array of bounding boxes and their scores
|
||||
- indices: which predictions to look at
|
||||
- iouThreshold: used to decide whether boxes overlap too much
|
||||
- maxBoxes: the maximum number of boxes that will be selected
|
||||
|
||||
- Returns: the array indices of the selected bounding boxes
|
||||
*/
|
||||
func nonMaxSuppression(boundingBoxes: [MLBoundingBox],
|
||||
indices: [Int],
|
||||
iouThreshold: Float,
|
||||
maxBoxes: Int) -> [Int] {
|
||||
|
||||
// Sort the boxes based on their confidence scores, from high to low.
|
||||
let sortedIndices = indices.sorted { boundingBoxes[$0].score > boundingBoxes[$1].score }
|
||||
|
||||
var selected: [Int] = []
|
||||
|
||||
// Loop through the bounding boxes, from highest score to lowest score,
|
||||
// and determine whether or not to keep each box.
|
||||
for i in 0..<sortedIndices.count {
|
||||
if selected.count >= maxBoxes { break }
|
||||
|
||||
var shouldSelect = true
|
||||
let boxA = boundingBoxes[sortedIndices[i]]
|
||||
|
||||
// Does the current box overlap one of the selected boxes more than the
|
||||
// given threshold amount? Then it's too similar, so don't keep it.
|
||||
for j in 0..<selected.count {
|
||||
let boxB = boundingBoxes[selected[j]]
|
||||
if IOU(boxA.rect, boxB.rect) > iouThreshold {
|
||||
shouldSelect = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// This bounding box did not overlap too much with any previously selected
|
||||
// bounding box, so we'll keep it.
|
||||
if shouldSelect {
|
||||
selected.append(sortedIndices[i])
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
/**
|
||||
Multi-class version of non maximum suppression.
|
||||
|
||||
Where `nonMaxSuppression()` does not look at the class of the predictions at
|
||||
all, the multi-class version first selects the best bounding boxes for each
|
||||
class, and then keeps the best ones of those.
|
||||
|
||||
With this method you can usually expect to see at least one bounding box for
|
||||
each class (unless all the scores for a given class are really low).
|
||||
|
||||
Based on code from: https://github.com/tensorflow/models/blob/master/object_detection/core/post_processing.py
|
||||
|
||||
- Parameters:
|
||||
- numClasses: the number of classes
|
||||
- boundingBoxes: an array of bounding boxes and their scores
|
||||
- scoreThreshold: used to only keep bounding boxes with a high enough score
|
||||
- iouThreshold: used to decide whether boxes overlap too much
|
||||
- maxPerClass: the maximum number of boxes that will be selected per class
|
||||
- maxTotal: maximum number of boxes that will be selected over all classes
|
||||
|
||||
- Returns: the array indices of the selected bounding boxes
|
||||
*/
|
||||
func nonMaxSuppressionMultiClass(numClasses: Int,
|
||||
boundingBoxes: [MLBoundingBox],
|
||||
scoreThreshold: Float,
|
||||
iouThreshold: Float,
|
||||
maxPerClass: Int,
|
||||
maxTotal: Int) -> [Int] {
|
||||
var selectedBoxes: [Int] = []
|
||||
|
||||
// Look at all the classes one-by-one.
|
||||
for c in 0..<numClasses {
|
||||
var filteredBoxes = [Int]()
|
||||
|
||||
// Look at every bounding box for this class.
|
||||
for p in 0..<boundingBoxes.count {
|
||||
let prediction = boundingBoxes[p]
|
||||
if prediction.classIndex == c {
|
||||
|
||||
// Only keep the box if its score is over the threshold.
|
||||
if prediction.score > scoreThreshold {
|
||||
filteredBoxes.append(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only keep the best bounding boxes for this class.
|
||||
let nmsBoxes = nonMaxSuppression(boundingBoxes: boundingBoxes,
|
||||
indices: filteredBoxes,
|
||||
iouThreshold: iouThreshold,
|
||||
maxBoxes: maxPerClass)
|
||||
|
||||
// Add the indices of the surviving boxes to the big list.
|
||||
selectedBoxes.append(contentsOf: nmsBoxes)
|
||||
}
|
||||
|
||||
// Sort all the surviving boxes by score and only keep the best ones.
|
||||
let sortedBoxes = selectedBoxes.sorted { boundingBoxes[$0].score > boundingBoxes[$1].score }
|
||||
return Array(sortedBoxes.prefix(maxTotal))
|
||||
}
|
Loading…
Reference in New Issue