Add reduce max ONNX op support (#1636)

* Add reduce max onnx op support

* Fix comments on tensor rank 1 result
This commit is contained in:
Guillaume Lagrange 2024-04-17 08:26:46 -04:00 committed by GitHub
parent 2d264e9a74
commit 424033283a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 226 additions and 4 deletions

View File

@ -139,7 +139,7 @@ represent the corresponding Burn Op.
| [ReduceL][132] | ❌ | ❌ |
| [ReduceLogSum][133] | ❌ | ❌ |
| [ReduceLogSumExp][134] | ❌ | ❌ |
| [ReduceMax][135] | ❌ | ✅ |
| [ReduceMax][135] | ✅ | ✅ |
| [ReduceMean][136] | ✅ | ✅ |
| [ReduceMin][137] | ❌ | ✅ |
| [ReduceProd][138] | ❌ | ✅ |

View File

@ -37,6 +37,7 @@ fn main() {
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/reduce_max/reduce_max.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/shape/shape.onnx")

View File

@ -45,6 +45,7 @@ include_models!(
neg,
not,
recip,
reduce_max,
reduce_mean,
relu,
reshape,
@ -449,6 +450,22 @@ mod tests {
output3.to_data().assert_approx_eq(&expected3, 3);
}
#[test]
fn reduce_max() {
    // The generated model returns three ReduceMax variants:
    // (keepdims=0/axes=None, keepdims=1/axes=[1], keepdims=1/axes=[-1]).
    let device = Default::default();
    let model: reduce_max::Model<Backend> = reduce_max::Model::new(&device);

    // Run the model
    let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
    let (output_scalar, output_tensor, output_value) = model.forward(input.clone());

    // Full reduction is lowered to `max()`, which yields a rank-1 tensor.
    assert_eq!(output_scalar.to_data(), Data::from([25.]));
    // Dim 1 has size 1, so max over it with keepdim leaves the input unchanged.
    assert_eq!(output_tensor.to_data(), input.to_data());
    // Max over the last dim with keepdim preserves rank 4.
    assert_eq!(output_value.to_data(), Data::from([[[[25.]]]]));
}
#[test]
fn reduce_mean() {
let device = Default::default();

View File

@ -0,0 +1,46 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/reduce_max/reduce_max.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
    """Torch module exporting the three ReduceMax variants exercised by the test.

    The forward pass returns a tuple that ONNX export lowers to:
    ReduceMax(keepdims=0, axes=None), ReduceMax(keepdims=1, axes=[1]),
    and ReduceMax(keepdims=1, axes=[-1]).
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        # ReduceMax, keepdims=0, axes=None
        full_reduction = torch.max(x)
        # ReduceMax, keepdims=1, axes=[1]
        along_dim1 = torch.max(x, dim=1, keepdim=True).values
        # ReduceMax, keepdims=1, axes=[-1]
        along_last = torch.max(x, dim=-1, keepdim=True).values
        return (full_reduction, along_dim1, along_last)
def main():
    """Export the ReduceMax test model to ONNX and print sample I/O for the Rust test."""
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "reduce_max.onnx"

    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    # BUG FIX: `model.forward(*test_input)` unpacked the 4-D tensor along dim 0,
    # feeding forward a 3-D tensor — the printed output would not match what the
    # exported ONNX model (traced with the 4-D input above) actually produces.
    output = model.forward(test_input)
    print(f"Test output data: {output}")


if __name__ == "__main__":
    main()

View File

@ -27,13 +27,14 @@ pub enum UnaryNodeKind {
Exp,
Flatten,
Gelu,
LeakyRelu,
Log,
LogSoftmax,
Neg,
Not,
ReduceMax,
ReduceMean,
Reciprocal,
LeakyRelu,
Relu,
Shape,
Sigmoid,
@ -53,13 +54,14 @@ impl UnaryNodeKind {
Self::Exp => "exp",
Self::Flatten => "flatten",
Self::Gelu => "gelu",
Self::LeakyRelu => "leaky_relu",
Self::Log => "log",
Self::LogSoftmax => "log_softmax",
Self::Neg => "neg",
Self::Not => "not",
Self::ReduceMax => "reduce_max",
Self::ReduceMean => "reduce_mean",
Self::Reciprocal => "reciprocal",
Self::LeakyRelu => "leaky_relu",
Self::Relu => "relu",
Self::Shape => "shape",
Self::Sigmoid => "sigmoid",
@ -294,6 +296,36 @@ impl UnaryNode {
}
}
/// Lowers an ONNX ReduceMax onto Burn's tensor ops.
///
/// `dim == Some(d)` maps to `max_dim(d)` (keepdims=1 over a single axis);
/// `dim == None` maps to `max()`, the full reduction.
pub(crate) fn reduce_max(input: Type, output: Type, dim: Option<usize>) -> Self {
    let Type::Tensor(ref tensor) = output else {
        panic!("ReduceMax only supports tensor output");
    };

    match dim {
        // ReduceMax, keepdims=1, axes=[dim]
        Some(dim) => {
            if tensor.kind == TensorKind::Bool {
                // Max is only implemented on numeric tensors
                panic!("ReduceMax is not supported for boolean");
            }
            let dim = dim.to_tokens();
            Self::new(
                input,
                output,
                UnaryNodeKind::ReduceMax,
                Rc::new(move |input| quote! { #input.max_dim(#dim) }),
            )
        }
        // ReduceMax, keepdims=0, axes=None
        None => Self::new(
            input,
            output,
            UnaryNodeKind::ReduceMax,
            Rc::new(move |input| quote! { #input.max() }),
        ),
    }
}
pub(crate) fn reduce_mean(input: Type, output: Type, dim: Option<usize>) -> Self {
// ReduceMean is constrained to numeric tensors, so no need to check for bool.
if let Type::Tensor(_) = output {
@ -519,6 +551,43 @@ mod tests {
);
}
#[test]
fn test_unary_codegen_reduce_max() {
// Case 1: single-axis reduction (dim=1, keepdims=1) lowers to `max_dim(1)`,
// preserving the input rank (4 -> 4).
one_node_graph(
UnaryNode::reduce_max(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Some(1),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.max_dim(1);
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
// Case 2: full reduction (axes=None) lowers to `max()`, which collapses
// to a rank-1 tensor (Burn has no 0-dim tensors).
one_node_graph(
UnaryNode::reduce_max(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 1)),
None,
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
let tensor2 = tensor1.max();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_reduce_mean() {
one_node_graph(

View File

@ -40,6 +40,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Neg => same_as_input(node),
NodeType::Not => same_as_input(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::ReduceMax => reduce_max_update_outputs(node),
NodeType::ReduceMean => reduce_mean_update_outputs(node),
NodeType::Relu => same_as_input(node),
NodeType::Reshape => reshape_update_outputs(node),
@ -239,8 +240,10 @@ fn reduce_mean_update_outputs(node: &mut Node) {
node.outputs[0].ty = ArgType::Tensor(tensor);
} else {
// NOTE: ReduceMean w/o keepdims reduces to a scalar value, but Burn doesn't have
// 0-dim tensor so we can't track or perform other ops on that value
// 0-dim tensor so we can't track or perform other ops on that value if we call
// `.into_scalar()` on the result of `tensor.max()`
// node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
// Instead, we return a tensor of rank 1 (the result of `tensor.max()`)
node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
}
}
@ -438,3 +441,36 @@ fn conv_transpose2d_update_outputs(node: &mut Node) {
panic!("Only tensor input is valid");
}
}
/// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
fn reduce_max_update_outputs(node: &mut Node) {
    if node.inputs.len() != 1 {
        panic!("ReduceMax: multiple inputs are not supported");
    }

    // Clone only the tensor type — the original cloned the whole input
    // `Argument` (including any constant payload) through a spurious `&mut`.
    let tensor = match &node.inputs[0].ty {
        ArgType::Tensor(tensor) => tensor.clone(),
        _ => panic!("Only tensor input is valid"),
    };

    // The codegen path only supports reducing along a single explicit axis;
    // anything else falls through to the full-reduction case below.
    let dim_only = match node.attrs.get("axes") {
        Some(value) => match &value {
            AttributeValue::Int64(_) => true,
            AttributeValue::Int64s(ints) => ints.len() == 1,
            _ => false,
        },
        None => false,
    };

    if dim_only {
        // keepdims=1 with a single axis: the rank is preserved.
        node.outputs[0].ty = ArgType::Tensor(tensor);
    } else {
        // NOTE: ReduceMax w/o keepdims reduces to a scalar value, but Burn doesn't have
        // 0-dim tensor so we can't track or perform other ops on that value if we call
        // `.into_scalar()` on the result of `tensor.max()`
        // node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
        // Instead, we return a tensor of rank 1 (the result of `tensor.max()`)
        node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
    }
}

View File

@ -661,6 +661,50 @@ fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d {
}
}
/// Extracts the ReduceMax configuration: `Some(dim)` for a single-axis
/// reduction with keepdims, or `None` for a full reduction.
///
/// Panics on configurations Burn cannot represent (multiple axes, or an
/// explicit axis without keepdims).
pub fn reduce_max_config(node: &Node) -> Option<usize> {
    let mut axes = Vec::new();
    let mut keepdims = 1;

    // Only the rank is needed (to normalize negative axes) — the original
    // cloned the entire input `Argument` just to read it.
    let tensor_rank = match &node.inputs.first().unwrap().ty {
        ArgType::Tensor(tensor) => tensor.dim,
        _ => panic!("Only tensor input is valid"),
    };

    // Extract the attributes
    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "axes" => axes = value.clone().into_i64s(),
            "keepdims" => keepdims = value.clone().into_i64(),
            _ => {}
        }
    }

    if axes.len() > 1 {
        panic!("ReduceMax: reducing on multiple dimensions is not supported")
    }

    if axes.is_empty() && keepdims == 1 {
        panic!("ReduceMax: axes must be provided with keepdims")
    }

    if !axes.is_empty() && keepdims == 0 {
        // Not supported in Burn
        panic!("ReduceMax: the reduce operation must preserve the reduced dimension")
    }

    if axes.is_empty() {
        None
    } else {
        let mut dim = axes[0];

        if dim < 0 {
            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
            dim += tensor_rank as i64;
        }
        Some(dim as usize)
    }
}
pub fn reduce_mean_config(node: &Node) -> Option<usize> {
let mut axes = Vec::new();
let mut keepdims = 1;

View File

@ -253,6 +253,7 @@ impl OnnxGraph {
NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
@ -467,6 +468,14 @@ impl OnnxGraph {
ReshapeNode::new(input, output, shape)
}
/// Converts an ONNX ReduceMax node into a Burn unary node.
fn reduce_max_conversion(node: Node) -> UnaryNode {
    // Resolve the reduction axis first (None means full reduction),
    // then pull the input/output types off the node.
    let dim = reduce_max_config(&node);
    let input = node.inputs.first().unwrap().to_type();
    let output = node.outputs.first().unwrap().to_type();

    UnaryNode::reduce_max(input, output, dim)
}
fn reduce_mean_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();