mirror of https://github.com/tracel-ai/burn.git
feat: Less + LessOrEqual onnx import (#1800)
This commit is contained in:
parent
e39b4d2da0
commit
1f31e20ce8
|
@ -89,8 +89,8 @@ represent the corresponding Burn Op.
|
|||
| [IsNaN][81] | ❌ | ❌ |
|
||||
| [LayerNormalization][82] | ✅ | ✅ |
|
||||
| [LeakyRelu][83] | ✅ | ✅ |
|
||||
| [Less][84] | ❌ | ✅ |
|
||||
| [LessOrEqual][85] | ❌ | ✅ |
|
||||
| [Less][84] | ✅ | ✅ |
|
||||
| [LessOrEqual][85] | ✅ | ✅ |
|
||||
| Linear | ✅ | ✅ |
|
||||
| [Log][87] | ✅ | ✅ |
|
||||
| [LogSoftmax][88] | ✅ | ✅ |
|
||||
|
|
|
@ -40,6 +40,8 @@ fn main() {
|
|||
.input("tests/mul/mul.onnx")
|
||||
.input("tests/neg/neg.onnx")
|
||||
.input("tests/not/not.onnx")
|
||||
.input("tests/less/less.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal.onnx")
|
||||
.input("tests/recip/recip.onnx")
|
||||
.input("tests/relu/relu.onnx")
|
||||
.input("tests/leaky_relu/leaky_relu.onnx")
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
pytorch2.3.0:<3A>
|
||||
,
|
||||
onnx::Less_0
|
||||
onnx::Less_12/Less"Less
|
||||
main_graphZ
|
||||
onnx::Less_0
|
||||
|
||||
|
||||
Z
|
||||
onnx::Less_1
|
||||
|
||||
|
||||
b
|
||||
2
|
||||
|
||||
|
||||
B
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/less/less.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x,y)
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
onnx_name = "less.onnx"
|
||||
|
||||
test_input1 = torch.randn(4, 4, device=device)
|
||||
test_input2 = torch.randn(4, 4, device=device)
|
||||
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
print("Test input data: {} {}".format(test_input1, test_input2))
|
||||
output = model.forward(test_input1, test_input2)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,17 @@
|
|||
pytorch2.3.0:¹
|
||||
H
|
||||
onnx::LessOrEqual_0
|
||||
onnx::LessOrEqual_12/LessOrEqual"LessOrEqual
|
||||
main_graphZ%
|
||||
onnx::LessOrEqual_0
|
||||
|
||||
|
||||
Z%
|
||||
onnx::LessOrEqual_1
|
||||
|
||||
|
||||
b
|
||||
2
|
||||
|
||||
|
||||
B
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/less_or_equal/less_or_equal.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.le(x,y)
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
onnx_name = "less_or_equal.onnx"
|
||||
|
||||
test_input1 = torch.randn(4, 4, device=device)
|
||||
test_input2 = torch.randn(4, 4, device=device)
|
||||
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
print("Test input data: {} {}".format(test_input1, test_input2))
|
||||
output = model.forward(test_input1, test_input2)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -51,6 +51,8 @@ include_models!(
|
|||
mul,
|
||||
neg,
|
||||
not,
|
||||
less,
|
||||
less_or_equal,
|
||||
prelu,
|
||||
recip,
|
||||
reduce_max,
|
||||
|
@ -1171,6 +1173,32 @@ mod tests {
|
|||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn less() {
|
||||
let device = Default::default();
|
||||
let model: less::Model<Backend> = less::Model::new(&device);
|
||||
|
||||
let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 25.0]], &device);
|
||||
let input2 = Tensor::<Backend, 2>::from_floats([[1.0, 5.0, 8.0, -25.0]], &device);
|
||||
|
||||
let output = model.forward(input1, input2);
|
||||
let expected = Data::from([[false, true, false, false]]);
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn less_or_equal() {
|
||||
let device = Default::default();
|
||||
let model: less_or_equal::Model<Backend> = less_or_equal::Model::new(&device);
|
||||
|
||||
let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 25.0]], &device);
|
||||
let input2 = Tensor::<Backend, 2>::from_floats([[1.0, 5.0, 8.0, -25.0]], &device);
|
||||
|
||||
let output = model.forward(input1, input2);
|
||||
let expected = Data::from([[true, true, false, false]]);
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_creation_with_a_default_device() {
|
||||
let device = Default::default();
|
||||
|
|
|
@ -16,6 +16,8 @@ pub enum BinaryType {
|
|||
Powi,
|
||||
Min,
|
||||
Max,
|
||||
Less,
|
||||
LessOrEqual,
|
||||
}
|
||||
|
||||
impl BinaryType {
|
||||
|
@ -30,6 +32,8 @@ impl BinaryType {
|
|||
BinaryType::Powf => "powf",
|
||||
BinaryType::Min => "min_pair",
|
||||
BinaryType::Max => "max_pair",
|
||||
BinaryType::Less => "lower",
|
||||
BinaryType::LessOrEqual => "lower_equal",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -193,6 +197,28 @@ impl BinaryNode {
|
|||
};
|
||||
Self::new(lhs, rhs, output, BinaryType::Max, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn lower(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.lower(#rhs) },
|
||||
_ => panic!("lower is supported for tensor only"),
|
||||
};
|
||||
Self::new(lhs, rhs, output, BinaryType::Less, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn lower_equal(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.lower_equal(#rhs) },
|
||||
_ => panic!("lower_equal is supported for tensor only"),
|
||||
};
|
||||
Self::new(
|
||||
lhs,
|
||||
rhs,
|
||||
output,
|
||||
BinaryType::LessOrEqual,
|
||||
Arc::new(function),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -358,6 +384,16 @@ mod tests {
|
|||
test_binary_operator_on_tensors!(max_pair);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_less() {
|
||||
test_binary_operator_on_tensors!(lower);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_less_or_equal() {
|
||||
test_binary_operator_on_tensors!(lower_equal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_equal_tensors() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
|
|
@ -46,6 +46,8 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
|
|||
NodeType::Mul => same_as_input(node),
|
||||
NodeType::Neg => same_as_input(node),
|
||||
NodeType::Not => same_as_input(node),
|
||||
NodeType::Less => less_update_outputs(node),
|
||||
NodeType::LessOrEqual => less_or_equal_update_outputs(node),
|
||||
NodeType::Reciprocal => same_as_input(node),
|
||||
NodeType::ReduceMax => reduce_max_update_outputs(node),
|
||||
NodeType::ReduceMean => reduce_mean_update_outputs(node),
|
||||
|
@ -237,6 +239,30 @@ fn reshape_update_outputs(node: &mut Node) {
|
|||
}
|
||||
}
|
||||
|
||||
fn less_update_outputs(node: &mut Node) {
|
||||
match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type: ElementType::Bool,
|
||||
..tensor.clone()
|
||||
});
|
||||
}
|
||||
_ => panic!("Only tensor input is valid"),
|
||||
}
|
||||
}
|
||||
|
||||
fn less_or_equal_update_outputs(node: &mut Node) {
|
||||
match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type: ElementType::Bool,
|
||||
..tensor.clone()
|
||||
});
|
||||
}
|
||||
_ => panic!("Only tensor input is valid"),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_mean_update_outputs(node: &mut Node) {
|
||||
if node.inputs.len() != 1 {
|
||||
panic!("Mean: multiple inputs are not supported");
|
||||
|
|
|
@ -251,6 +251,8 @@ impl OnnxGraph {
|
|||
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
|
||||
NodeType::Neg => graph.register(Self::neg_conversion(node)),
|
||||
NodeType::Not => graph.register(Self::not_conversion(node)),
|
||||
NodeType::Less => graph.register(Self::less_conversion(node)),
|
||||
NodeType::LessOrEqual => graph.register(Self::less_or_equal_conversion(node)),
|
||||
NodeType::LayerNormalization => {
|
||||
graph.register(Self::layer_norm_conversion::<PS>(node))
|
||||
}
|
||||
|
@ -822,6 +824,22 @@ impl OnnxGraph {
|
|||
UnaryNode::not(input, output)
|
||||
}
|
||||
|
||||
fn less_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
|
||||
BinaryNode::lower(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn less_or_equal_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
|
||||
BinaryNode::lower_equal(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn pow_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
|
|
Loading…
Reference in New Issue