[MLIR][Presburger] Add support for piece-wise multi-affine functions

Add the class MultiAffineFunction which represents functions whose domain is an
IntegerPolyhedron and which produce an output given by a tuple of affine
expressions in the IntegerPolyhedron's ids.

Also add support for piece-wise MultiAffineFunctions, which are defined on a
union of IntegerPolyhedrons, and may have different output affine expressions
on each IntegerPolyhedron. Thus the function is affine on each individual
IntegerPolyhedron piece in the domain.

This is part of a series of patches leading up to parametric integer programming.

Depends on D118778.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D118779
This commit is contained in:
Arjun P 2022-02-08 00:31:27 +05:30
parent 570471199b
commit d5a2944219
7 changed files with 599 additions and 19 deletions

View File

@ -56,6 +56,7 @@ public:
enum class Kind {
FlatAffineConstraints,
FlatAffineValueConstraints,
MultiAffineFunction,
IntegerPolyhedron
};
@ -194,6 +195,11 @@ public:
/// Adds an equality from the coefficients specified in `eq`.
void addEquality(ArrayRef<int64_t> eq);
/// Eliminate the `posB^th` local identifier, replacing every instance of it
/// with the `posA^th` local identifier. This should be used when the two
/// local variables are known to always take the same values.
virtual void eliminateRedundantLocalId(unsigned posA, unsigned posB);
/// Removes identifiers of the specified kind with the specified pos (or
/// within the specified range) from the system. The specified location is
/// relative to the first identifier of the specified kind.
@ -273,6 +279,9 @@ public:
/// Returns true if the given point satisfies the constraints, or false
/// otherwise.
///
/// Note: currently, if the polyhedron contains local ids, the values of
/// the local ids must also be provided.
bool containsPoint(ArrayRef<int64_t> point) const;
/// Find equality and pairs of inequality contraints identified by their

View File

@ -0,0 +1,195 @@
//===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Support for piece-wise multi-affine functions. These are functions that are
// defined on a domain that is a union of IntegerPolyhedrons, and on each domain
// the value of the function is a tuple of integers, with each value in the
// tuple being an affine expression in the ids of the IntegerPolyhedron.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
#define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
#include "mlir/Analysis/Presburger/PresburgerSet.h"
namespace mlir {
/// This class represents a multi-affine function whose domain is given by an
/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
/// tuple of integer values attached to every point in the polyhedron, with the
/// value of each element of the tuple given by an affine expression in the ids
/// of the polyhedron. For example we could have the domain
///
/// (x, y) : (x >= 5, y >= x)
///
/// and a tuple of three integers defined at every point in the polyhedron:
///
/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
///
/// In this way every point in the polyhedron has a tuple of integers associated
/// with it. If the integer polyhedron has local ids, then the output
/// expressions can use them as well. The output expressions are represented as
/// a matrix with one row for every element in the output vector one column for
/// each id, and an extra column at the end for the constant term.
///
/// Checking equality of two such functions is supported, as well as finding the
/// value of the function at a specified point. Note that local ids in the
/// domain are not yet supported for finding the value at a point.
class MultiAffineFunction : protected IntegerPolyhedron {
public:
/// We use protected inheritance to avoid inheriting the whole public
/// interface of IntegerPolyhedron. These using declarations explicitly make
/// only the relevant functions part of the public interface.
using IntegerPolyhedron::getNumDimAndSymbolIds;
using IntegerPolyhedron::getNumDimIds;
using IntegerPolyhedron::getNumIds;
using IntegerPolyhedron::getNumLocalIds;
using IntegerPolyhedron::getNumSymbolIds;
MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
: IntegerPolyhedron(domain), output(output) {}
MultiAffineFunction(const Matrix &output, unsigned numDims,
unsigned numSymbols = 0, unsigned numLocals = 0)
: IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
~MultiAffineFunction() override = default;
Kind getKind() const override { return Kind::MultiAffineFunction; }
bool classof(const IntegerPolyhedron *poly) const {
return poly->getKind() == Kind::MultiAffineFunction;
}
unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
unsigned getNumOutputs() const { return output.getNumRows(); }
bool isConsistent() const { return output.getNumColumns() == numIds + 1; }
const IntegerPolyhedron &getDomain() const { return *this; }
bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
/// Insert `num` identifiers of the specified kind at position `pos`.
/// Positions are relative to the kind of identifier. The coefficient columns
/// corresponding to the added identifiers are initialized to zero. Return the
/// absolute column position (i.e., not relative to the kind of identifier)
/// of the first added identifier.
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
/// Swap the posA^th identifier with the posB^th identifier.
void swapId(unsigned posA, unsigned posB) override;
/// Remove the specified range of ids.
void removeIdRange(unsigned idStart, unsigned idLimit) override;
/// Eliminate the `posB^th` local identifier, replacing every instance of it
/// with the `posA^th` local identifier. This should be used when the two
/// local variables are known to always take the same values.
void eliminateRedundantLocalId(unsigned posA, unsigned posB) override;
/// Return whether the outputs of `this` and `other` agree wherever both
/// functions are defined, i.e., the outputs should be equal for all points in
/// the intersection of the domains.
bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
/// Return whether the `this` and `other` are equal. This is the case if
/// they lie in the same space, i.e. have the same dimensions, and their
/// domains are identical and their outputs are equal on their domain.
bool isEqual(const MultiAffineFunction &other) const;
/// Get the value of the function at the specified point. If the point lies
/// outside the domain, an empty optional is returned.
///
/// Note: domains with local ids are not yet supported, and will assert-fail.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
void print(raw_ostream &os) const;
void dump() const;
private:
/// The function's output is a tuple of integers, with the ith element of the
/// tuple defined by the affine expression given by the ith row of this output
/// matrix.
Matrix output;
};
/// This class represents a piece-wise MultiAffineFunction. This can be thought
/// of as a list of MultiAffineFunction with disjoint domains, with each having
/// their own affine expressions for their output tuples. For example, we could
/// have a function with two input variables (x, y), defined as
///
/// f(x, y) = (2*x + y, y - 4) if x >= 0, y >= 0
/// = (-2*x + y, y + 4) if x < 0, y < 0
/// = (4, 1) if x < 0, y >= 0
///
/// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of
/// this class is undefined. The domains need not cover all possible points;
/// this represents a partial function and so could be undefined at some points.
///
/// As in PresburgerSets, the input ids are partitioned into dimension ids and
/// symbolic ids.
///
/// Support is provided to compare equality of two such functions as well as
/// finding the value of the function at a point. Note that local ids in the
/// piece are not supported for the latter.
class PWMAFunction {
public:
PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
: numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) {
assert(numOutputs >= 1 && "The function must output something!");
}
void addPiece(const MultiAffineFunction &piece);
void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
unsigned getNumPieces() const { return pieces.size(); }
unsigned getNumOutputs() const { return numOutputs; }
unsigned getNumInputs() const { return numDims + numSymbols; }
unsigned getNumDimIds() const { return numDims; }
unsigned getNumSymbolIds() const { return numSymbols; }
MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
/// Return the domain of this piece-wise MultiAffineFunction. This is the
/// union of the domains of all the pieces.
PresburgerSet getDomain() const;
/// Check whether the `this` and the given function have compatible
/// dimensions, i.e., the same number of dimension inputs, symbol inputs, and
/// outputs.
bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
bool hasCompatibleDimensions(const PWMAFunction &f) const;
/// Return the value at the specified point and an empty optional if the
/// point does not lie in the domain.
///
/// Note: domains with local ids are not yet supported, and will assert-fail.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
/// they have the same dimensions, the same domain and they take the same
/// value at every point in the domain.
bool isEqual(const PWMAFunction &other) const;
void print(raw_ostream &os) const;
void dump() const;
private:
/// The list of pieces in this piece-wise MultiAffineFunction.
SmallVector<MultiAffineFunction, 4> pieces;
/// The number of dimensions ids in the domains.
unsigned numDims;
/// The number of symbol ids in the domains.
unsigned numSymbols;
/// The number of output ids.
unsigned numOutputs;
};
} // namespace mlir
#endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H

View File

@ -3,6 +3,7 @@ add_mlir_library(MLIRPresburger
LinearTransform.cpp
Matrix.cpp
PresburgerSet.cpp
PWMAFunction.cpp
Simplex.cpp
Utils.cpp

View File

@ -1065,24 +1065,17 @@ void IntegerPolyhedron::removeRedundantConstraints() {
equalities.resizeVertically(pos);
}
/// Eliminate `pos2^th` local identifier, replacing its every instance with
/// `pos1^th` local identifier. This function is intended to be used to remove
/// redundancy when local variables at position `pos1` and `pos2` are restricted
/// to have the same value.
static void eliminateRedundantLocalId(IntegerPolyhedron &poly, unsigned pos1,
unsigned pos2) {
void IntegerPolyhedron::eliminateRedundantLocalId(unsigned posA,
unsigned posB) {
assert(posA < getNumLocalIds() && "Invalid local id position");
assert(posB < getNumLocalIds() && "Invalid local id position");
assert(pos1 < poly.getNumLocalIds() && "Invalid local id position");
assert(pos2 < poly.getNumLocalIds() && "Invalid local id position");
unsigned localOffset = poly.getNumDimAndSymbolIds();
pos1 += localOffset;
pos2 += localOffset;
for (unsigned i = 0, e = poly.getNumInequalities(); i < e; ++i)
poly.atIneq(i, pos1) += poly.atIneq(i, pos2);
for (unsigned i = 0, e = poly.getNumEqualities(); i < e; ++i)
poly.atEq(i, pos1) += poly.atEq(i, pos2);
poly.removeId(pos2);
unsigned localOffset = getIdKindOffset(IdKind::Local);
posA += localOffset;
posB += localOffset;
inequalities.addToColumn(posB, posA, 1);
equalities.addToColumn(posB, posA, 1);
removeId(posB);
}
/// Adds additional local ids to the sets such that they both have the union
@ -1129,8 +1122,8 @@ void IntegerPolyhedron::mergeLocalIds(IntegerPolyhedron &other) {
// Merge function that merges the local variables in both sets by treating
// them as the same identifier.
auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool {
eliminateRedundantLocalId(polyA, i, j);
eliminateRedundantLocalId(polyB, i, j);
polyA.eliminateRedundantLocalId(i, j);
polyB.eliminateRedundantLocalId(i, j);
return true;
};

View File

@ -0,0 +1,198 @@
//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Simplex.h"
using namespace mlir;
// Return the result of subtracting the two given vectors pointwise.
// The vectors must be of the same size.
// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
ArrayRef<int64_t> vecB) {
assert(vecA.size() == vecB.size() &&
"Cannot subtract vectors of differing lengths!");
SmallVector<int64_t, 8> result;
result.reserve(vecA.size());
for (unsigned i = 0, e = vecA.size(); i < e; ++i)
result.push_back(vecA[i] - vecB[i]);
return result;
}
PresburgerSet PWMAFunction::getDomain() const {
PresburgerSet domain =
PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
for (const MultiAffineFunction &piece : pieces)
domain.unionPolyInPlace(piece.getDomain());
return domain;
}
Optional<SmallVector<int64_t, 8>>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
if (!getDomain().containsPoint(point))
return {};
// The point lies in the domain, so we need to compute the output value.
// The matrix `output` has an affine expression in the ith row, corresponding
// to the expression for the ith value in the output vector. The last column
// of the matrix contains the constant term. Let v be the input point with
// a 1 appended at the end. We can see that output * v gives the desired
// output vector.
SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
pointHomogenous.push_back(1);
SmallVector<int64_t, 8> result =
output.postMultiplyWithColumn(pointHomogenous);
assert(result.size() == getNumOutputs());
return result;
}
Optional<SmallVector<int64_t, 8>>
PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
assert(point.size() == getNumInputs() &&
"Point has incorrect dimensionality!");
for (const MultiAffineFunction &piece : pieces)
if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
return output;
return {};
}
void MultiAffineFunction::print(raw_ostream &os) const {
os << "Domain:";
IntegerPolyhedron::print(os);
os << "Output:\n";
output.print(os);
os << "\n";
}
void MultiAffineFunction::dump() const { print(llvm::errs()); }
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
return hasCompatibleDimensions(other) &&
getDomain().isEqual(other.getDomain()) &&
isEqualWhereDomainsOverlap(other);
}
unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
unsigned num) {
unsigned absolutePos = getIdKindOffset(kind) + pos;
output.insertColumns(absolutePos, num);
return IntegerPolyhedron::insertId(kind, pos, num);
}
void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
output.swapColumns(posA, posB);
IntegerPolyhedron::swapId(posA, posB);
}
void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
output.removeColumns(idStart, idLimit - idStart);
IntegerPolyhedron::removeIdRange(idStart, idLimit);
}
void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
unsigned posB) {
output.addToColumn(posB, posA, /*scale=*/1);
IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
}
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
MultiAffineFunction other) const {
if (!hasCompatibleDimensions(other))
return false;
// `commonFunc` has the same output as `this`.
MultiAffineFunction commonFunc = *this;
// After this merge, `commonFunc` and `other` have the same local ids; they
// are merged.
commonFunc.mergeLocalIds(other);
// After this, the domain of `commonFunc` will be the intersection of the
// domains of `this` and `other`.
commonFunc.IntegerPolyhedron::append(other);
// `commonDomainMatching` contains the subset of the common domain
// where the outputs of `this` and `other` match.
//
// We want to add constraints equating the outputs of `this` and `other`.
// However, `this` may have difference local ids from `other`, whereas we
// need both to have the same locals. Accordingly, we use `commonFunc.output`
// in place of `this->output`, since `commonFunc` has the same output but also
// has its locals merged.
IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
commonDomainMatching.addEquality(
subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
// If the whole common domain is a subset of commonDomainMatching, then they
// are equal and the two functions match on the whole common domain.
return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
}
/// Two PWMAFunctions are equal if they have the same dimensionalities,
/// the same domain, and take the same value at every point in the domain.
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
if (!hasCompatibleDimensions(other))
return false;
if (!this->getDomain().isEqual(other.getDomain()))
return false;
// Check if, whenever the domains of a piece of `this` and a piece of `other`
// overlap, they take the same output value. If `this` and `other` have the
// same domain (checked above), then this check passes iff the two functions
// have the same output at every point in the domain.
for (const MultiAffineFunction &aPiece : this->pieces)
for (const MultiAffineFunction &bPiece : other.pieces)
if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
return false;
return true;
}
void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
assert(hasCompatibleDimensions(piece) &&
"Piece to be added is not compatible with this PWMAFunction!");
assert(piece.isConsistent() && "Piece is internally inconsistent!");
assert(this->getDomain()
.intersect(PresburgerSet(piece.getDomain()))
.isIntegerEmpty() &&
"New piece's domain overlaps with that of existing pieces!");
pieces.push_back(piece);
}
void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
const Matrix &output) {
addPiece(MultiAffineFunction(domain, output));
}
void PWMAFunction::print(raw_ostream &os) const {
os << pieces.size() << " pieces:\n";
for (const MultiAffineFunction &piece : pieces)
piece.print(os);
}
/// The hasCompatibleDimensions functions don't check the number of local ids;
/// functions are still compatible if they have differing number of locals.
bool MultiAffineFunction::hasCompatibleDimensions(
const MultiAffineFunction &f) const {
return getNumDimIds() == f.getNumDimIds() &&
getNumSymbolIds() == f.getNumSymbolIds() &&
getNumOutputs() == f.getNumOutputs();
}
bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const {
return getNumDimIds() == f.getNumDimIds() &&
getNumSymbolIds() == f.getNumSymbolIds() &&
getNumOutputs() == f.getNumOutputs();
}
bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const {
return getNumDimIds() == f.getNumDimIds() &&
getNumSymbolIds() == f.getNumSymbolIds() &&
getNumOutputs() == f.getNumOutputs();
}

View File

@ -3,6 +3,7 @@ add_mlir_unittest(MLIRPresburgerTests
LinearTransformTest.cpp
MatrixTest.cpp
PresburgerSetTest.cpp
PWMAFunctionTest.cpp
SimplexTest.cpp
../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
)

View File

@ -0,0 +1,183 @@
//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains tests for PWMAFunction.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "../../Dialect/Affine/Analysis/AffineStructuresParser.h"
#include "mlir/Analysis/Presburger/PresburgerSet.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace mlir {
using testing::ElementsAre;
/// Parses an IntegerPolyhedron from a StringRef. It is expected that the
/// string represents a valid IntegerSet, otherwise it will violate a gtest
/// assertion.
static IntegerPolyhedron parsePoly(StringRef str, MLIRContext *context) {
FailureOr<IntegerPolyhedron> poly = parseIntegerSetToFAC(str, context);
EXPECT_TRUE(succeeded(poly));
return *poly;
}
static Matrix makeMatrix(unsigned numRow, unsigned numColumns,
ArrayRef<SmallVector<int64_t, 8>> matrix) {
Matrix results(numRow, numColumns);
assert(matrix.size() == numRow);
for (unsigned i = 0; i < numRow; ++i) {
assert(matrix[i].size() == numColumns &&
"Output expression has incorrect dimensionality!");
for (unsigned j = 0; j < numColumns; ++j)
results(i, j) = matrix[i][j];
}
return results;
}
/// Construct a PWMAFunction given the dimensionalities and an array describing
/// the list of pieces. Each piece is given by a string describing the domain
/// and a 2D array that represents the output.
static PWMAFunction parsePWMAF(
unsigned numInputs, unsigned numOutputs,
ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
data,
unsigned numSymbols = 0) {
static MLIRContext context;
PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs);
for (const auto &pair : data) {
IntegerPolyhedron domain = parsePoly(pair.first, &context);
result.addPiece(
domain, makeMatrix(numOutputs, domain.getNumIds() + 1, pair.second));
}
return result;
}
TEST(PWAFunctionTest, isEqual) {
MLIRContext context;
// The output expressions are different but it doesn't matter because they are
// equal in this domain.
PWMAFunction idAtZeros = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
{"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
{"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
});
PWMAFunction idAtZeros2 = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
{"(x, y) : (y - 1 >= 0, x == 0)", {{30, 0, 0}, {0, 1, 0}}}, //(30x, y)
{"(x, y) : (-y - 1 > =0, x == 0)", {{30, 0, 0}, {0, 1, 0}}} //(30x, y)
});
EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
PWMAFunction notIdAtZeros = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
{"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
{"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
});
EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
// These match at their intersection but one has a bigger domain.
PWMAFunction idNoNegNegQuadrant = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
{"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
});
PWMAFunction idOnlyPosX =
parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
});
EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
// Different representations of the same domain.
PWMAFunction sumPlusOne = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/1,
{
{"(x, y) : (x >= 0)", {{1, 1, 1}}}, // x + y + 1.
{"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
{"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}} // x + y + 1.
});
PWMAFunction sumPlusOne2 =
parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
{
{"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
});
EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
// Functions with zero input dimensions.
PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
{
{"() : ()", {{1}}}, // 1.
});
PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
{
{"() : ()", {{2}}}, // 1.
});
EXPECT_TRUE(noInputs1.isEqual(noInputs1));
EXPECT_FALSE(noInputs1.isEqual(noInputs2));
// Mismatched dimensionalities.
EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
// Divisions.
// Domain is only multiples of 6; x = 6k for some k.
// x + 4(x/2) + 4(x/3) == 26k.
PWMAFunction mul2AndMul3 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
{{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
});
PWMAFunction mul6 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
});
EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
PWMAFunction mul6diff = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/6).
});
EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
PWMAFunction mul5 = parsePWMAF(
/*numInputs=*/1, /*numOutputs=*/1,
{
{"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
});
EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
}
TEST(PWMAFunction, valueAt) {
PWMAFunction nonNegPWAF = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
{"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
});
EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
}
} // namespace mlir