mirror of https://github.com/tracel-ai/burn.git
feat: added slice onnx import (#1856)
* feat: added slice onnx import * fix: axes, steps handling
This commit is contained in:
parent
dd60446946
commit
671ec8c679
|
@ -171,7 +171,7 @@ represent the corresponding Burn Op.
|
||||||
| [Sin][164] | ✅ | ✅ |
|
| [Sin][164] | ✅ | ✅ |
|
||||||
| [Sinh][165] | ❌ | ❌ |
|
| [Sinh][165] | ❌ | ❌ |
|
||||||
| [Size][166] | ❌ | ❌ |
|
| [Size][166] | ❌ | ❌ |
|
||||||
| [Slice][167] | ❌ | ✅ |
|
| [Slice][167] | ✅ | ✅ |
|
||||||
| [Softmax][168] | ✅ | ✅ |
|
| [Softmax][168] | ✅ | ✅ |
|
||||||
| [SoftmaxCrossEntropyLoss][169] | ❌ | ❌ |
|
| [SoftmaxCrossEntropyLoss][169] | ❌ | ❌ |
|
||||||
| [Softplus][170] | ❌ | ❌ |
|
| [Softplus][170] | ❌ | ❌ |
|
||||||
|
|
|
@ -69,6 +69,7 @@ fn main() {
|
||||||
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
|
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
|
||||||
.input("tests/pow/pow.onnx")
|
.input("tests/pow/pow.onnx")
|
||||||
.input("tests/pow/pow_int.onnx")
|
.input("tests/pow/pow_int.onnx")
|
||||||
|
.input("tests/slice/slice.onnx")
|
||||||
.input("tests/sum/sum.onnx")
|
.input("tests/sum/sum.onnx")
|
||||||
.input("tests/sum/sum_int.onnx")
|
.input("tests/sum/sum_int.onnx")
|
||||||
.input("tests/unsqueeze/unsqueeze.onnx")
|
.input("tests/unsqueeze/unsqueeze.onnx")
|
||||||
|
|
|
@ -71,6 +71,7 @@ include_models!(
|
||||||
sigmoid,
|
sigmoid,
|
||||||
sign,
|
sign,
|
||||||
sin,
|
sin,
|
||||||
|
slice,
|
||||||
softmax,
|
softmax,
|
||||||
sqrt,
|
sqrt,
|
||||||
sub_int,
|
sub_int,
|
||||||
|
@ -459,6 +460,24 @@ mod tests {
|
||||||
assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
|
assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn slice() {
|
||||||
|
let model: slice::Model<Backend> = slice::Model::default();
|
||||||
|
let device = Default::default();
|
||||||
|
|
||||||
|
let input = Tensor::<Backend, 2>::from_floats(
|
||||||
|
[
|
||||||
|
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
|
||||||
|
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
|
||||||
|
],
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
let output = model.forward(input);
|
||||||
|
let expected = Data::from([[1., 2., 3., 4., 5.]]);
|
||||||
|
|
||||||
|
assert_eq!(output.to_data(), expected);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn softmax() {
|
fn softmax() {
|
||||||
// Initialize the model without weights (because the exported file does not contain them)
|
// Initialize the model without weights (because the exported file does not contain them)
|
||||||
|
|
Binary file not shown.
|
@ -0,0 +1,101 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# used to generate model: onnx-tests/tests/slice/slice.onnx
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
from onnx import helper, TensorProto
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
# Starts
|
||||||
|
starts_val = [0,0] # Example shape value
|
||||||
|
starts_tensor = helper.make_tensor(
|
||||||
|
name="starts",
|
||||||
|
data_type=TensorProto.INT64,
|
||||||
|
dims=[len(starts_val)],
|
||||||
|
vals=starts_val,
|
||||||
|
)
|
||||||
|
starts_node = helper.make_node(
|
||||||
|
"Constant",
|
||||||
|
name="starts_constant",
|
||||||
|
inputs=[],
|
||||||
|
outputs=["starts"],
|
||||||
|
value=starts_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ends
|
||||||
|
ends_val = [1,5] # Example shape value
|
||||||
|
ends_tensor = helper.make_tensor(
|
||||||
|
name="ends",
|
||||||
|
data_type=TensorProto.INT64,
|
||||||
|
dims=[len(ends_val)],
|
||||||
|
vals=ends_val,
|
||||||
|
)
|
||||||
|
ends_node = helper.make_node(
|
||||||
|
"Constant",
|
||||||
|
name="ends_constant",
|
||||||
|
inputs=[],
|
||||||
|
outputs=["ends"],
|
||||||
|
value=ends_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Axes
|
||||||
|
axes_val = [0,1] # Example shape value
|
||||||
|
axes_tensor = helper.make_tensor(
|
||||||
|
name="axes",
|
||||||
|
data_type=TensorProto.INT64,
|
||||||
|
dims=[len(axes_val)],
|
||||||
|
vals=axes_val,
|
||||||
|
)
|
||||||
|
axes_node = helper.make_node(
|
||||||
|
"Constant",
|
||||||
|
name="axes_constant",
|
||||||
|
inputs=[],
|
||||||
|
outputs=["axes"],
|
||||||
|
value=axes_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Steps
|
||||||
|
steps_val = [1, 1] # Example shape value
|
||||||
|
steps_tensor = helper.make_tensor(
|
||||||
|
name="steps",
|
||||||
|
data_type=TensorProto.INT64,
|
||||||
|
dims=[len(steps_val)],
|
||||||
|
vals=steps_val,
|
||||||
|
)
|
||||||
|
steps_node = helper.make_node(
|
||||||
|
"Constant",
|
||||||
|
name="steps_constant",
|
||||||
|
inputs=[],
|
||||||
|
outputs=["steps"],
|
||||||
|
value=steps_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define the Slice node that uses the outputs from the constant nodes
|
||||||
|
slice_node = helper.make_node(
|
||||||
|
"Slice",
|
||||||
|
name="slice_node",
|
||||||
|
inputs=["input_tensor", "starts", "ends", "axes", "steps"],
|
||||||
|
outputs=["output"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the graph
|
||||||
|
graph_def = helper.make_graph(
|
||||||
|
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
|
||||||
|
name="SliceGraph",
|
||||||
|
inputs=[
|
||||||
|
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5])
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the model
|
||||||
|
model_def = helper.make_model(graph_def, producer_name="slice")
|
||||||
|
|
||||||
|
# Save the model to a file
|
||||||
|
onnx.save(model_def, "slice.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -7,7 +7,7 @@ use super::{
|
||||||
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
|
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
|
||||||
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
|
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
|
||||||
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
|
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
|
||||||
reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
|
reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
|
||||||
unsqueeze::UnsqueezeNode,
|
unsqueeze::UnsqueezeNode,
|
||||||
};
|
};
|
||||||
use crate::burn::{BurnImports, Scope, Type};
|
use crate::burn::{BurnImports, Scope, Type};
|
||||||
|
@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
|
||||||
MaxPool2d(MaxPool2dNode),
|
MaxPool2d(MaxPool2dNode),
|
||||||
Range(RangeNode),
|
Range(RangeNode),
|
||||||
Reshape(ReshapeNode),
|
Reshape(ReshapeNode),
|
||||||
|
Slice(SliceNode),
|
||||||
Squeeze(SqueezeNode),
|
Squeeze(SqueezeNode),
|
||||||
Sum(SumNode),
|
Sum(SumNode),
|
||||||
Unary(UnaryNode),
|
Unary(UnaryNode),
|
||||||
|
@ -139,6 +140,7 @@ macro_rules! match_all {
|
||||||
Node::MaxPool2d(node) => $func(node),
|
Node::MaxPool2d(node) => $func(node),
|
||||||
Node::Range(node) => $func(node),
|
Node::Range(node) => $func(node),
|
||||||
Node::Reshape(node) => $func(node),
|
Node::Reshape(node) => $func(node),
|
||||||
|
Node::Slice(node) => $func(node),
|
||||||
Node::Squeeze(node) => $func(node),
|
Node::Squeeze(node) => $func(node),
|
||||||
Node::Sum(node) => $func(node),
|
Node::Sum(node) => $func(node),
|
||||||
Node::Unary(node) => $func(node),
|
Node::Unary(node) => $func(node),
|
||||||
|
@ -186,6 +188,7 @@ impl<PS: PrecisionSettings> Node<PS> {
|
||||||
Node::MaxPool2d(_) => "max_pool2d",
|
Node::MaxPool2d(_) => "max_pool2d",
|
||||||
Node::Range(_) => "range",
|
Node::Range(_) => "range",
|
||||||
Node::Reshape(_) => "reshape",
|
Node::Reshape(_) => "reshape",
|
||||||
|
Node::Slice(_) => "slice",
|
||||||
Node::Squeeze(_) => "squeeze",
|
Node::Squeeze(_) => "squeeze",
|
||||||
Node::Sum(_) => "add",
|
Node::Sum(_) => "add",
|
||||||
Node::Unary(unary) => unary.kind.as_str(),
|
Node::Unary(unary) => unary.kind.as_str(),
|
||||||
|
|
|
@ -27,6 +27,7 @@ pub(crate) mod random_normal;
|
||||||
pub(crate) mod random_uniform;
|
pub(crate) mod random_uniform;
|
||||||
pub(crate) mod range;
|
pub(crate) mod range;
|
||||||
pub(crate) mod reshape;
|
pub(crate) mod reshape;
|
||||||
|
pub(crate) mod slice;
|
||||||
pub(crate) mod squeeze;
|
pub(crate) mod squeeze;
|
||||||
pub(crate) mod sum;
|
pub(crate) mod sum;
|
||||||
pub(crate) mod unary;
|
pub(crate) mod unary;
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
use super::{Node, NodeCodegen};
|
||||||
|
use crate::burn::{Scope, TensorType, Type};
|
||||||
|
use burn::record::PrecisionSettings;
|
||||||
|
use proc_macro2::TokenStream;
|
||||||
|
use quote::quote;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, new)]
|
||||||
|
pub struct SliceNode {
|
||||||
|
pub input: TensorType,
|
||||||
|
pub output: TensorType,
|
||||||
|
pub starts: Vec<usize>,
|
||||||
|
pub ends: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
|
||||||
|
fn output_types(&self) -> Vec<Type> {
|
||||||
|
vec![Type::Tensor(self.output.clone())]
|
||||||
|
}
|
||||||
|
fn input_types(&self) -> Vec<Type> {
|
||||||
|
vec![Type::Tensor(self.input.clone())]
|
||||||
|
}
|
||||||
|
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||||
|
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||||
|
let output = &self.output.name;
|
||||||
|
let starts = &self.starts;
|
||||||
|
let ends = &self.ends;
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
let #output = #input.slice([#(#starts..#ends),*]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn into_node(self) -> Node<PS> {
|
||||||
|
Node::Slice(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use burn::record::FullPrecisionSettings;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::burn::{
|
||||||
|
graph::BurnGraph,
|
||||||
|
node::{slice::SliceNode, test::assert_tokens},
|
||||||
|
TensorType,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_codegen_slice() {
|
||||||
|
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||||
|
graph.register(SliceNode::new(
|
||||||
|
TensorType::new_float("tensor1", 4),
|
||||||
|
TensorType::new_float("tensor2", 4),
|
||||||
|
vec![0, 0, 0, 0],
|
||||||
|
vec![1, 1, 1, 1],
|
||||||
|
));
|
||||||
|
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
|
||||||
|
|
||||||
|
let expected = quote! {
|
||||||
|
use burn::{
|
||||||
|
module::Module,
|
||||||
|
tensor::{backend::Backend, Tensor},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct Model<B: Backend> {
|
||||||
|
phantom: core::marker::PhantomData<B>,
|
||||||
|
device: burn::module::Ignored<B::Device>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Model <B> {
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
pub fn new(device: &B::Device) -> Self {
|
||||||
|
Self {
|
||||||
|
phantom: core::marker::PhantomData,
|
||||||
|
device: burn::module::Ignored(device.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||||
|
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
|
let tensor2 = tensor1.slice([0usize..1usize,0usize..1usize,0usize..1usize,0usize..1usize]);
|
||||||
|
|
||||||
|
tensor2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_tokens(graph.codegen(), expected);
|
||||||
|
}
|
||||||
|
}
|
|
@ -63,6 +63,7 @@ pub fn dim_inference(node: &mut Node) {
|
||||||
NodeType::Sigmoid => same_as_input(node),
|
NodeType::Sigmoid => same_as_input(node),
|
||||||
NodeType::Sign => same_as_input(node),
|
NodeType::Sign => same_as_input(node),
|
||||||
NodeType::Sin => same_as_input(node),
|
NodeType::Sin => same_as_input(node),
|
||||||
|
NodeType::Slice => slice_update_outputs(node),
|
||||||
NodeType::Softmax => same_as_input(node),
|
NodeType::Softmax => same_as_input(node),
|
||||||
NodeType::Sqrt => same_as_input(node),
|
NodeType::Sqrt => same_as_input(node),
|
||||||
NodeType::Sub => same_as_input(node),
|
NodeType::Sub => same_as_input(node),
|
||||||
|
@ -423,6 +424,33 @@ fn squeeze_update_output(node: &mut Node) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn slice_update_outputs(node: &mut Node) {
|
||||||
|
let shape = match &node.inputs[1].value {
|
||||||
|
Some(value) => match value {
|
||||||
|
Data::Int64s(shape) => Some(shape.clone()),
|
||||||
|
_ => panic!("Slice: invalid input types"),
|
||||||
|
},
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
if shape.is_none() {
|
||||||
|
panic!("Slice: invalid shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = match &node.outputs[0].ty {
|
||||||
|
ArgType::Tensor(tensor) => tensor.clone(),
|
||||||
|
_ => panic!("Slice: invalid output types"),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(shape) = shape {
|
||||||
|
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||||
|
dim: shape.len(),
|
||||||
|
shape: None, // shape is calculated at runtime
|
||||||
|
..output
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
||||||
fn unsqueeze_update_output(node: &mut Node) {
|
fn unsqueeze_update_output(node: &mut Node) {
|
||||||
let axes = if node.inputs.len() == 2 {
|
let axes = if node.inputs.len() == 2 {
|
||||||
|
|
|
@ -18,7 +18,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};
|
||||||
|
|
||||||
use protobuf::Message;
|
use protobuf::Message;
|
||||||
|
|
||||||
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
|
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [
|
||||||
NodeType::BatchNormalization,
|
NodeType::BatchNormalization,
|
||||||
NodeType::Clip,
|
NodeType::Clip,
|
||||||
NodeType::Conv1d,
|
NodeType::Conv1d,
|
||||||
|
@ -28,6 +28,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
|
||||||
NodeType::Reshape,
|
NodeType::Reshape,
|
||||||
NodeType::Unsqueeze,
|
NodeType::Unsqueeze,
|
||||||
NodeType::ReduceSum,
|
NodeType::ReduceSum,
|
||||||
|
NodeType::Slice,
|
||||||
NodeType::Squeeze,
|
NodeType::Squeeze,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
|
@ -1015,6 +1015,67 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
|
||||||
(start_dim as usize, end_dim as usize)
|
(start_dim as usize, end_dim as usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
|
||||||
|
let start_value = &node.inputs[1].value;
|
||||||
|
let end_value = &node.inputs[2].value;
|
||||||
|
|
||||||
|
let starts = match &node.inputs[1].ty {
|
||||||
|
ArgType::Tensor(tensor) => {
|
||||||
|
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
|
||||||
|
if let Some(Data::Int64s(shape)) = start_value.as_ref() {
|
||||||
|
shape
|
||||||
|
.iter()
|
||||||
|
.map(|x| {
|
||||||
|
assert!(*x >= 0, "Slice: start must be positive");
|
||||||
|
*x as usize
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
panic!("Tensor data type must be int64")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Only tensor input is valid for shape"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let ends = match &node.inputs[2].ty {
|
||||||
|
ArgType::Tensor(tensor) => {
|
||||||
|
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
|
||||||
|
if let Some(Data::Int64s(shape)) = end_value.as_ref() {
|
||||||
|
shape
|
||||||
|
.iter()
|
||||||
|
.map(|x| {
|
||||||
|
assert!(*x >= 0, "Slice: end must be positive");
|
||||||
|
*x as usize
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
panic!("Tensor data type must be int64")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Only tensor input is valid for shape"),
|
||||||
|
};
|
||||||
|
|
||||||
|
for (key, value) in node.attrs.iter() {
|
||||||
|
match key.as_str() {
|
||||||
|
"axes" => {
|
||||||
|
let mut i = 0;
|
||||||
|
value.clone().into_i64s().iter().for_each(|x| {
|
||||||
|
assert_eq!(*x, i, "Slice: axes must be consecutive");
|
||||||
|
i += 1;
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"steps" => value.clone().into_i64s().into_iter().for_each(|x| {
|
||||||
|
if x != 1 {
|
||||||
|
panic!("Slice: steps other than 1 are not supported");
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(starts, ends)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn transpose_config(curr: &Node) -> Vec<i64> {
|
pub fn transpose_config(curr: &Node) -> Vec<i64> {
|
||||||
if curr.inputs.len() != 1 {
|
if curr.inputs.len() != 1 {
|
||||||
panic!(
|
panic!(
|
||||||
|
|
|
@ -42,6 +42,7 @@ use crate::{
|
||||||
random_uniform::RandomUniformNode,
|
random_uniform::RandomUniformNode,
|
||||||
range::RangeNode,
|
range::RangeNode,
|
||||||
reshape::ReshapeNode,
|
reshape::ReshapeNode,
|
||||||
|
slice::SliceNode,
|
||||||
squeeze::SqueezeNode,
|
squeeze::SqueezeNode,
|
||||||
sum::SumNode,
|
sum::SumNode,
|
||||||
unary::UnaryNode,
|
unary::UnaryNode,
|
||||||
|
@ -294,6 +295,7 @@ impl OnnxGraph {
|
||||||
NodeType::Shape => graph.register(Self::shape_conversion(node)),
|
NodeType::Shape => graph.register(Self::shape_conversion(node)),
|
||||||
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
|
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
|
||||||
NodeType::Sin => graph.register(Self::sin_conversion(node)),
|
NodeType::Sin => graph.register(Self::sin_conversion(node)),
|
||||||
|
NodeType::Slice => graph.register(Self::slice_conversion(node)),
|
||||||
NodeType::Sum => graph.register(Self::sum_conversion(node)),
|
NodeType::Sum => graph.register(Self::sum_conversion(node)),
|
||||||
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
|
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
|
||||||
NodeType::Concat => graph.register(Self::concat_conversion(node)),
|
NodeType::Concat => graph.register(Self::concat_conversion(node)),
|
||||||
|
@ -686,6 +688,14 @@ impl OnnxGraph {
|
||||||
UnaryNode::sin(input, output)
|
UnaryNode::sin(input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn slice_conversion(node: Node) -> SliceNode {
|
||||||
|
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||||
|
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||||
|
let (starts, ends) = slice_config(&node);
|
||||||
|
|
||||||
|
SliceNode::new(input, output, starts, ends)
|
||||||
|
}
|
||||||
|
|
||||||
fn sum_conversion(node: Node) -> SumNode {
|
fn sum_conversion(node: Node) -> SumNode {
|
||||||
let inputs = node
|
let inputs = node
|
||||||
.inputs
|
.inputs
|
||||||
|
|
Loading…
Reference in New Issue