mirror of https://github.com/tracel-ai/burn.git
Add reduce sum onnx ops to burn imports (#1723)
This commit is contained in:
parent
0b919b6a58
commit
fb13503fa9
@@ -143,7 +143,7 @@ represent the corresponding Burn Op.
 | [ReduceMean][136] | ✅ | ✅ |
 | [ReduceMin][137] | ❌ | ✅ |
 | [ReduceProd][138] | ❌ | ✅ |
-| [ReduceSum][139] | ❌ | ✅ |
+| [ReduceSum][139] | ✅ | ✅ |
 | [ReduceSumSquare][140] | ❌ | ❌ |
 | [Relu][141] | ✅ | ✅ |
 | [Reshape][142] | ✅ | ✅ |
@@ -42,6 +42,8 @@ fn main() {
         .input("tests/prelu/prelu.onnx")
         .input("tests/reduce_max/reduce_max.onnx")
         .input("tests/reduce_mean/reduce_mean.onnx")
+        .input("tests/reduce_sum/reduce_sum_opset13.onnx")
+        .input("tests/reduce_sum/reduce_sum_opset11.onnx")
         .input("tests/reshape/reshape.onnx")
         .input("tests/shape/shape.onnx")
         .input("tests/sigmoid/sigmoid.onnx")
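For context, these `.input(...)` calls are chained on burn-import's `ModelGen` builder inside the onnx-tests `build.rs`. A minimal sketch of that surrounding setup (the `out_dir` value is an assumption for illustration, not taken from this diff):

    // build.rs (sketch): hand the ONNX fixtures to burn-import's code generator,
    // which emits Rust model sources at build time. Paths are illustrative.
    use burn_import::onnx::ModelGen;

    fn main() {
        ModelGen::new()
            .input("tests/reduce_sum/reduce_sum_opset13.onnx")
            .input("tests/reduce_sum/reduce_sum_opset11.onnx")
            .out_dir("model/")
            .run_from_script();
    }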
@@ -51,6 +51,8 @@ include_models!(
     recip,
     reduce_max,
     reduce_mean,
+    reduce_sum_opset13,
+    reduce_sum_opset11,
     relu,
     reshape,
     shape,
@@ -545,6 +547,38 @@ mod tests {
         assert_eq!(output_value.to_data(), expected);
     }
 
+    #[test]
+    fn reduce_sum_opset11() {
+        let device = Default::default();
+        let model: reduce_sum_opset11::Model<Backend> = reduce_sum_opset11::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
+        let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
+        let expected_scalar = Data::from([39.]);
+        let expected = Data::from([[[[39.]]]]);
+
+        assert_eq!(output_scalar.to_data(), expected_scalar);
+        assert_eq!(output_tensor.to_data(), input.to_data());
+        assert_eq!(output_value.to_data(), expected);
+    }
+
+    #[test]
+    fn reduce_sum_opset13() {
+        let device = Default::default();
+        let model: reduce_sum_opset13::Model<Backend> = reduce_sum_opset13::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
+        let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
+        let expected_scalar = Data::from([39.]);
+        let expected = Data::from([[[[39.]]]]);
+
+        assert_eq!(output_scalar.to_data(), expected_scalar);
+        assert_eq!(output_tensor.to_data(), input.to_data());
+        assert_eq!(output_value.to_data(), expected);
+    }
+
     #[test]
     fn reshape() {
         // Initialize the model without weights (because the exported file does not contain them)
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+# used to generate models: onnx-tests/tests/reduce_sum/reduce_sum_opset11.onnx and reduce_sum_opset13.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        return (
+            # ReduceSum, keepdims=0, axes=None
+            torch.sum(x),
+            # ReduceSum, keepdims=1, axes=[1]
+            torch.sum(x, dim=1, keepdim=True),
+            # ReduceSum, keepdims=1, axes=[-1]
+            torch.sum(x, dim=-1, keepdim=True),
+        )
+
+
+def main():
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
+
+    torch.onnx.export(model, test_input, "reduce_sum_opset11.onnx", verbose=False, opset_version=11)
+    torch.onnx.export(model, test_input, "reduce_sum_opset13.onnx", verbose=False, opset_version=13)
+
+    print("Finished exporting model")
+
+    # Output some test data for use in the test
+    print(f"Test input data: {test_input}")
+    output = model.forward(test_input)
+    print(f"Test output data: {output}")
+
+
+if __name__ == "__main__":
+    main()
Binary file not shown. (tests/reduce_sum/reduce_sum_opset11.onnx)
Binary file not shown. (tests/reduce_sum/reduce_sum_opset13.onnx)
@@ -34,6 +34,7 @@ pub enum UnaryNodeKind {
     Not,
     ReduceMax,
     ReduceMean,
+    ReduceSum,
     Reciprocal,
     Relu,
     Shape,
@@ -62,6 +63,7 @@ impl UnaryNodeKind {
             Self::Not => "not",
             Self::ReduceMax => "reduce_max",
             Self::ReduceMean => "reduce_mean",
+            Self::ReduceSum => "reduce_sum",
             Self::Reciprocal => "reciprocal",
             Self::Relu => "relu",
             Self::Shape => "shape",
@@ -355,6 +357,36 @@ impl UnaryNode {
         }
     }
 
+    pub(crate) fn reduce_sum(input: Type, output: Type, dim: Option<usize>) -> Self {
+        if let Type::Tensor(ref tensor) = output {
+            if let Some(dim) = dim {
+                if tensor.kind == TensorKind::Bool {
+                    // Sum is only implemented on numeric tensors
+                    panic!("ReduceSum is not supported for boolean");
+                }
+
+                // ReduceSum, keepdims=1, axes=[dim]
+                let dim = dim.to_tokens();
+                Self::new(
+                    input,
+                    output,
+                    UnaryNodeKind::ReduceSum,
+                    Rc::new(move |input| quote! { #input.sum_dim(#dim) }),
+                )
+            } else {
+                // ReduceSum, keepdims=0, axes=None
+                Self::new(
+                    input,
+                    output,
+                    UnaryNodeKind::ReduceSum,
+                    Rc::new(move |input| quote! { #input.sum() }),
+                )
+            }
+        } else {
+            panic!("ReduceSum only supports tensor output");
+        }
+    }
+
     pub(crate) fn shape(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self {
         // Shape as defined by the ONNX op should return a tensor because other ops
         // (e.g., Gather) will be used on a tensor
@@ -634,6 +666,43 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_unary_codegen_reduce_sum() {
+        one_node_graph(
+            UnaryNode::reduce_sum(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 4)),
+                Some(1),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.sum_dim(1);
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+
+        one_node_graph(
+            UnaryNode::reduce_sum(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 1)),
+                None,
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
+                    let tensor2 = tensor1.sum();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+    }
+
     #[test]
     fn test_unary_codegen_reciprocal() {
         one_node_graph(
@@ -45,6 +45,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Reciprocal => same_as_input(node),
         NodeType::ReduceMax => reduce_max_update_outputs(node),
         NodeType::ReduceMean => reduce_mean_update_outputs(node),
+        NodeType::ReduceSum => reduce_sum_update_outputs(node),
         NodeType::Relu => same_as_input(node),
        NodeType::Reshape => reshape_update_outputs(node),
         NodeType::Shape => shape_update_outputs(node),
@@ -461,6 +462,44 @@ fn reduce_max_update_outputs(node: &mut Node) {
     }
 }
 
+/// Infers the shape of a ReduceSum node and replaces the shape of the output tensor.
+fn reduce_sum_update_outputs(node: &mut Node) {
+    let node_input = &mut node.inputs[0];
+    let tensor = match node_input.clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    let dim_only = match node.attrs.get("axes") {
+        Some(value) => match &value {
+            AttributeValue::Int64(_) => true,
+            AttributeValue::Int64s(ints) => ints.len() == 1,
+            _ => false,
+        },
+        None => false,
+    };
+
+    let dim_only = match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) {
+        Some(value) => match &value {
+            Data::Int64(_) => true,
+            Data::Int64s(ints) => ints.len() == 1,
+            _ => false,
+        },
+        None => dim_only,
+    };
+
+    if dim_only {
+        node.outputs[0].ty = ArgType::Tensor(tensor);
+    } else {
+        // NOTE: ReduceSum w/o keepdims reduces to a scalar value, but Burn doesn't have
+        // 0-dim tensor so we can't track or perform other ops on that value if we call
+        // `.into_scalar()` on the result of `tensor.sum()`
+        // node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
+        // Instead, we return a tensor of rank 1 (the result of `tensor.sum()`)
+        node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
+    }
+}
+
 fn where_update_outputs(node: &mut Node) {
     match (
         node.inputs[0].ty.clone(),
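Two notes on the hunk above. ReduceSum carries `axes` as an attribute through opset 11 but as an optional second input from opset 13 onward, which is why `dim_only` is derived first from `node.attrs` and then overridden from `node.inputs.get(1)`. And the rank-1 fallback in the `else` branch mirrors Burn's tensor API, sketched below (a minimal example assuming the burn-ndarray backend; not part of this diff):

    use burn::tensor::Tensor;
    use burn_ndarray::NdArray;

    fn main() {
        let device = Default::default();
        let input = Tensor::<NdArray, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);

        // sum_dim keeps the rank: the reduced axis stays with size 1 (keepdims=1).
        let kept: Tensor<NdArray, 4> = input.clone().sum_dim(3);

        // sum collapses to a rank-1 tensor with one element, which is why the
        // inferred output type above is `TensorType { dim: 1, ..tensor }`.
        let total: Tensor<NdArray, 1> = input.sum();

        println!("{} {}", kept, total);
    }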
@@ -17,7 +17,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};
 
 use protobuf::Message;
 
-const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
+const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 8] = [
     NodeType::BatchNormalization,
     NodeType::Clip,
     NodeType::Conv1d,
@@ -25,6 +25,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
     NodeType::Dropout,
     NodeType::Reshape,
     NodeType::Unsqueeze,
+    NodeType::ReduceSum,
 ];
 
 #[derive(Debug)]
@@ -798,6 +798,60 @@ pub fn reduce_mean_config(node: &Node) -> Option<usize> {
     }
 }
 
+pub fn reduce_sum_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "keepdims" => keepdims = value.clone().into_i64(),
+            "axes" => axes = value.clone().into_i64s(),
+            // TODO: handle noop_with_empty_axes
+            _ => {}
+        }
+    }
+
+    // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode.
+    if let Some(value) = node
+        .inputs
+        .get(1)
+        .and_then(|argument| argument.value.as_ref())
+    {
+        axes = value.clone().into_i64s();
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceSum: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceSum: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        // Not supported in Burn
+        panic!("ReduceSum: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
+            dim += tensor.dim as i64;
+        }
+        Some(dim as usize)
+    }
+}
+
 pub fn shape_config(curr: &Node) -> (usize, usize) {
     if curr.inputs.len() != 1 {
         panic!(
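The tail of `reduce_sum_config` maps ONNX's accepted axis range `[-r, r-1]` onto the non-negative `usize` index Burn expects. A standalone sketch of the same arithmetic (`normalize_axis` is a hypothetical helper for illustration, not part of the crate):

    /// Map an ONNX axis in [-rank, rank-1] to the non-negative index Burn expects.
    fn normalize_axis(axis: i64, rank: usize) -> usize {
        let rank = rank as i64;
        assert!((-rank..rank).contains(&axis), "axis out of range");
        if axis < 0 { (axis + rank) as usize } else { axis as usize }
    }

    fn main() {
        assert_eq!(normalize_axis(-1, 4), 3); // last axis of a rank-4 tensor
        assert_eq!(normalize_axis(1, 4), 1);
    }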
@@ -265,6 +265,7 @@ impl OnnxGraph {
             NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
             NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
             NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
+            NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
             NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
             NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
             NodeType::Shape => graph.register(Self::shape_conversion(node)),
@@ -501,6 +502,14 @@ impl OnnxGraph {
         UnaryNode::reduce_mean(input, output, dim)
     }
 
+    fn reduce_sum_conversion(node: Node) -> UnaryNode {
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
+        let dim = reduce_sum_config(&node);
+
+        UnaryNode::reduce_sum(input, output, dim)
+    }
+
     fn shape_conversion(node: Node) -> UnaryNode {
         let input = node.inputs.first().unwrap().to_type();
         let output = node.outputs.first().unwrap().to_type();