[mlir][python] support taking ops instead of values in op constructors
Introduce support for accepting ops instead of values when constructing ops. A single-result op can be used instead of a value, including in lists of values, and any op can be used instead of a list of values. This is similar to, but more powerful, than the C++ API that allows for implicitly casting an OpType to Value if it is statically known to have a single result - the cast in Python is based on the op dynamically having a single result, and also handles the multi-result case. This allows to build IR in a more concise way: op = dialect.produce_multiple_results() other = dialect.produce_single_result() dialect.consume_multiple_results(other, op) instead of having to access the results manually op = dialect.produce.multiple_results() other = dialect.produce_single_result() dialect.consume_multiple_results(other.result, op.operation.results) The dispatch is implemented directly in Python and is triggered automatically for autogenerated OpView subclasses. Extension OpView classes should use the functions provided in ods_common.py if they want to implement this behavior. An alternative could be to implement the dispatch in the C++ bindings code, but it would require to forward opaque types through all Python functions down to a binding call, which makes it hard to inspect them in Python, e.g., to obtain the types of values. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111306
This commit is contained in:
parent
cb879d00d8
commit
b164f23c29
|
@ -10,6 +10,7 @@ try:
|
|||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
|
||||
def isa(cls: Type, ty: Type):
|
||||
try:
|
||||
|
@ -26,11 +27,12 @@ class FillOp:
|
|||
results = []
|
||||
if isa(RankedTensorType, output.type):
|
||||
results = [output.type]
|
||||
op = self.build_generic(results=results,
|
||||
operands=[value, output],
|
||||
attributes=None,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
op = self.build_generic(
|
||||
results=results,
|
||||
operands=[_get_op_result_or_value(o) for o in [value, output]],
|
||||
attributes=None,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
OpView.__init__(self, op)
|
||||
linalgDialect = Context.current.get_dialect_descriptor("linalg")
|
||||
fill_builtin_region(linalgDialect, self.operation)
|
||||
|
|
|
@ -5,11 +5,14 @@
|
|||
# Provide a convenient name for sub-packages to resolve the main C-extension
|
||||
# with a relative import.
|
||||
from .._mlir_libs import _mlir as _cext
|
||||
from typing import Sequence as _Sequence, Union as _Union
|
||||
|
||||
__all__ = [
|
||||
"equally_sized_accessor",
|
||||
"extend_opview_class",
|
||||
"get_default_loc_context",
|
||||
"get_op_result_or_value",
|
||||
"get_op_results_or_values",
|
||||
"segmented_accessor",
|
||||
]
|
||||
|
||||
|
@ -118,3 +121,38 @@ def get_default_loc_context(location=None):
|
|||
# Location.current raises ValueError if there is no current location.
|
||||
return _cext.ir.Location.current.context
|
||||
return location.context
|
||||
|
||||
|
||||
def get_op_result_or_value(
|
||||
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
|
||||
) -> _cext.ir.Value:
|
||||
"""Returns the given value or the single result of the given op.
|
||||
|
||||
This is useful to implement op constructors so that they can take other ops as
|
||||
arguments instead of requiring the caller to extract results for every op.
|
||||
Raises ValueError if provided with an op that doesn't have a single result.
|
||||
"""
|
||||
if isinstance(arg, _cext.ir.OpView):
|
||||
return arg.operation.result
|
||||
elif isinstance(arg, _cext.ir.Operation):
|
||||
return arg.result
|
||||
else:
|
||||
assert isinstance(arg, _cext.ir.Value)
|
||||
return arg
|
||||
|
||||
|
||||
def get_op_results_or_values(
|
||||
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
|
||||
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
|
||||
"""Returns the given sequence of values or the results of the given op.
|
||||
|
||||
This is useful to implement op constructors so that they can take other ops as
|
||||
lists of arguments instead of requiring the caller to extract results for
|
||||
every op.
|
||||
"""
|
||||
if isinstance(arg, _cext.ir.OpView):
|
||||
return arg.operation.results
|
||||
elif isinstance(arg, _cext.ir.Operation):
|
||||
return arg.results
|
||||
else:
|
||||
return arg
|
||||
|
|
|
@ -7,8 +7,8 @@ try:
|
|||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Any, Sequence
|
||||
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
|
||||
|
||||
class ForOp:
|
||||
"""Specialization for the SCF for op class."""
|
||||
|
@ -17,7 +17,8 @@ class ForOp:
|
|||
lower_bound,
|
||||
upper_bound,
|
||||
step,
|
||||
iter_args: Sequence[Any] = [],
|
||||
iter_args: Optional[Union[Operation, OpView,
|
||||
Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
|
@ -26,14 +27,22 @@ class ForOp:
|
|||
- `lower_bound` is the value to use as lower bound of the loop.
|
||||
- `upper_bound` is the value to use as upper bound of the loop.
|
||||
- `step` is the value to use as loop step.
|
||||
- `iter_args` is a list of additional loop-carried arguments.
|
||||
- `iter_args` is a list of additional loop-carried arguments or an operation
|
||||
producing them as results.
|
||||
"""
|
||||
if iter_args is None:
|
||||
iter_args = []
|
||||
iter_args = _get_op_results_or_values(iter_args)
|
||||
|
||||
results = [arg.type for arg in iter_args]
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
regions=1,
|
||||
results=results,
|
||||
operands=[lower_bound, upper_bound, step] + list(iter_args),
|
||||
operands=[
|
||||
_get_op_result_or_value(o)
|
||||
for o in [lower_bound, upper_bound, step]
|
||||
] + list(iter_args),
|
||||
loc=loc,
|
||||
ip=ip))
|
||||
self.regions[0].blocks.append(IndexType.get(), *results)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
from contextlib import contextmanager
|
||||
import functools
|
||||
|
@ -10,12 +10,15 @@ import inspect
|
|||
import threading
|
||||
|
||||
from ..... import ir
|
||||
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
|
||||
from .comprehension import *
|
||||
from .config import *
|
||||
from .emitter import *
|
||||
|
||||
_CONTEXT = threading.local()
|
||||
|
||||
StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
|
||||
Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
|
||||
|
||||
@contextmanager
|
||||
def bind_op_def(model: LinalgOpDef):
|
||||
|
@ -37,6 +40,15 @@ def current_op_def() -> LinalgOpDef:
|
|||
"but none is set. Did you mean to call this in an op definition?")
|
||||
|
||||
|
||||
def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
|
||||
if isinstance(outs, (ir.Operation, ir.OpView)):
|
||||
return _get_op_results_or_values(outs)
|
||||
elif isinstance(outs, ir.OpResultList):
|
||||
return outs
|
||||
|
||||
return [_get_op_result_or_value(o) for o in outs]
|
||||
|
||||
|
||||
class DefinedOpCallable:
|
||||
"""Callable that wraps any defined op function."""
|
||||
|
||||
|
@ -44,7 +56,8 @@ class DefinedOpCallable:
|
|||
self.op_name = op_name
|
||||
self.model = model
|
||||
|
||||
def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
|
||||
def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
|
||||
outs: StructuredOpOuts, **kwargs):
|
||||
"""Emits the corresponding op definition as IR.
|
||||
|
||||
Most arguments are passed through to the underlying emitter. The following
|
||||
|
@ -73,17 +86,19 @@ class DefinedOpCallable:
|
|||
emit_generic or not ctx.is_registered_operation(fully_qualified_name))
|
||||
|
||||
op_config = op_configs[0]
|
||||
out_values = _prepare_structured_op_outs(outs)
|
||||
in_values = [_get_op_result_or_value(i) for i in ins]
|
||||
if op_config.structured_op:
|
||||
if emit_generic:
|
||||
return emit_generic_structured_op(
|
||||
op_config.structured_op, *ins, outs=outs, **kwargs)
|
||||
op_config.structured_op, *in_values, outs=out_values, **kwargs)
|
||||
else:
|
||||
return emit_named_structured_op(
|
||||
op_config.structured_op,
|
||||
self.op_name,
|
||||
self.model.metadata.cpp_class_name,
|
||||
*ins,
|
||||
outs=outs,
|
||||
*in_values,
|
||||
outs=out_values,
|
||||
**kwargs)
|
||||
|
||||
raise NotImplementedError(
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Dict, Sequence
|
||||
from typing import Dict, List, Sequence, Tuple, Union
|
||||
|
||||
from .....ir import *
|
||||
from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
|
||||
|
@ -10,6 +10,7 @@ from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
|
|||
from .... import linalg
|
||||
from .... import std
|
||||
from .... import math
|
||||
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
|
||||
|
||||
from .scalar_expr import *
|
||||
from .config import *
|
||||
|
@ -18,8 +19,10 @@ import numpy as np
|
|||
__all__ = [
|
||||
"emit_generic_structured_op",
|
||||
"emit_named_structured_op",
|
||||
"ValueList",
|
||||
]
|
||||
|
||||
ValueList = Union[Sequence[Value], OpResultList]
|
||||
|
||||
def isa(cls: Type, ty: Type):
|
||||
try:
|
||||
|
@ -30,17 +33,18 @@ def isa(cls: Type, ty: Type):
|
|||
|
||||
|
||||
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
||||
*ins: Value, outs: Sequence[Value],
|
||||
*ins: Value, outs: ValueList,
|
||||
**attrs: Sequence[int]):
|
||||
all_arg_defs = op_config.ordered_operands
|
||||
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
|
||||
out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
|
||||
attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
|
||||
|
||||
# Verify outs is a sequence.
|
||||
if not isinstance(outs, Sequence):
|
||||
raise ValueError(f"Expected named argument outs to have type Sequence "
|
||||
f"but got {type(outs)}")
|
||||
# Verify outs is a sequence or a list of results.
|
||||
if not isinstance(outs, (Sequence, OpResultList)):
|
||||
raise ValueError(
|
||||
f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
|
||||
)
|
||||
|
||||
# Arity validation.
|
||||
if len(ins) != len(in_arg_defs):
|
||||
|
@ -122,7 +126,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
|||
|
||||
|
||||
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
|
||||
outs: Sequence[Value], **attrs: Sequence[int]):
|
||||
outs: ValueList, **attrs: Sequence[int]):
|
||||
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
|
||||
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
|
||||
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
|
||||
|
@ -153,8 +157,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
|
|||
|
||||
|
||||
def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
|
||||
op_class_name: str, *ins: Value,
|
||||
outs: Sequence[Value], **attrs: Sequence[int]):
|
||||
op_class_name: str, *ins: Value, outs: ValueList,
|
||||
**attrs: Sequence[int]):
|
||||
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
|
||||
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
|
||||
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
|
||||
|
@ -355,11 +359,11 @@ class _BodyBuilder:
|
|||
return std.MinUIOp(lhs.type, lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
|
||||
|
||||
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
|
||||
in_arg_defs: Sequence[OperandDefConfig],
|
||||
ins: Sequence[Value],
|
||||
out_arg_defs: Sequence[OperandDefConfig],
|
||||
outs: Sequence[Value]):
|
||||
def _infer_structured_outs(
|
||||
op_config: LinalgStructuredOpConfig,
|
||||
in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
|
||||
out_arg_defs: Sequence[OperandDefConfig],
|
||||
outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
|
||||
"""Infers implicit outs and output types.
|
||||
|
||||
Respects existing contents of outs if not empty.
|
||||
|
|
|
@ -24,9 +24,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
|||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operands.append(variadic1)
|
||||
// CHECK: operands.append(non_variadic)
|
||||
// CHECK: if variadic2 is not None: operands.append(variadic2)
|
||||
// CHECK: operands.append(_get_op_results_or_values(variadic1))
|
||||
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
|
||||
// CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
|
@ -150,8 +150,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
|
|||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operands.append(_gen_arg_0)
|
||||
// CHECK: operands.append(_gen_arg_2)
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
|
||||
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
|
||||
// CHECK: _ods_get_default_loc_context(loc))
|
||||
// CHECK: if is_ is not None: attributes["is"] = is_
|
||||
|
@ -197,9 +197,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
|
|||
// CHECK: results.append(i32)
|
||||
// CHECK: results.append(_gen_res_1)
|
||||
// CHECK: results.append(i64)
|
||||
// CHECK: operands.append(_gen_arg_0)
|
||||
// CHECK: operands.append(f32)
|
||||
// CHECK: operands.append(_gen_arg_2)
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
|
||||
// CHECK: operands.append(_get_op_result_or_value(f32))
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
|
@ -230,8 +230,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
|
|||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operands.append(non_variadic)
|
||||
// CHECK: operands.extend(variadic)
|
||||
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
|
||||
// CHECK: operands.extend(_get_op_results_or_values(variadic))
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
|
@ -285,7 +285,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
|
|||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operands.append(in_)
|
||||
// CHECK: operands.append(_get_op_result_or_value(in_))
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
|
@ -353,8 +353,8 @@ def SimpleOp : TestOp<"simple"> {
|
|||
// CHECK: attributes = {}
|
||||
// CHECK: results.append(i64)
|
||||
// CHECK: results.append(f64)
|
||||
// CHECK: operands.append(i32)
|
||||
// CHECK: operands.append(f32)
|
||||
// CHECK: operands.append(_get_op_result_or_value(i32))
|
||||
// CHECK: operands.append(_get_op_result_or_value(f32))
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
|
|
|
@ -185,3 +185,30 @@ def testNamedStructuredAsGenericOp():
|
|||
return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
|
||||
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOpResultFromOtherOp
|
||||
@run
|
||||
def testOpResultFromOtherOp():
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
|
||||
f32))
|
||||
def pass_an_op_directly(arg0, arg1):
|
||||
one = std.ConstantOp(F32Type.get(), 1.0)
|
||||
# CHECK: %[[LHS:.*]] = linalg.fill
|
||||
lhs = linalg.FillOp(arg0, one)
|
||||
# CHECK: %[[RHS:.*]] = linalg.fill
|
||||
rhs = linalg.FillOp(arg1, one)
|
||||
# CHECK: %[[INIT:.*]] = linalg.init_tensor
|
||||
init = linalg.InitTensorOp([4, 8], f32)
|
||||
# CHECK: linalg.matmul
|
||||
# CHECK: ins(%[[LHS]], %[[RHS]]
|
||||
# CHECK: outs(%[[INIT]]
|
||||
return linalg.matmul(lhs, rhs, outs=init)
|
||||
|
||||
print(module)
|
||||
|
|
|
@ -2,53 +2,82 @@
|
|||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import scf
|
||||
from mlir.dialects import std
|
||||
from mlir.dialects import builtin
|
||||
|
||||
|
||||
def run(f):
|
||||
def constructAndPrintInModule(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
f()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testSimpleLoop
|
||||
@run
|
||||
@constructAndPrintInModule
|
||||
def testSimpleLoop():
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
index_type = IndexType.get()
|
||||
with InsertionPoint(module.body):
|
||||
index_type = IndexType.get()
|
||||
|
||||
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
|
||||
def simple_loop(lb, ub, step):
|
||||
loop = scf.ForOp(lb, ub, step, [lb, lb])
|
||||
with InsertionPoint(loop.body):
|
||||
scf.YieldOp(loop.inner_iter_args)
|
||||
return
|
||||
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
|
||||
def simple_loop(lb, ub, step):
|
||||
loop = scf.ForOp(lb, ub, step, [lb, lb])
|
||||
with InsertionPoint(loop.body):
|
||||
scf.YieldOp(loop.inner_iter_args)
|
||||
return
|
||||
|
||||
# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
|
||||
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
|
||||
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
|
||||
# CHECK: scf.yield %[[I1]], %[[I2]]
|
||||
print(module)
|
||||
|
||||
# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
|
||||
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
|
||||
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
|
||||
# CHECK: scf.yield %[[I1]], %[[I2]]
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testInductionVar
|
||||
@run
|
||||
@constructAndPrintInModule
|
||||
def testInductionVar():
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
index_type = IndexType.get()
|
||||
with InsertionPoint(module.body):
|
||||
index_type = IndexType.get()
|
||||
|
||||
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
|
||||
def induction_var(lb, ub, step):
|
||||
loop = scf.ForOp(lb, ub, step, [lb])
|
||||
with InsertionPoint(loop.body):
|
||||
scf.YieldOp([loop.induction_variable])
|
||||
return
|
||||
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
|
||||
def induction_var(lb, ub, step):
|
||||
loop = scf.ForOp(lb, ub, step, [lb])
|
||||
with InsertionPoint(loop.body):
|
||||
scf.YieldOp([loop.induction_variable])
|
||||
return
|
||||
|
||||
# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
|
||||
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
|
||||
# CHECK: scf.yield %[[IV]]
|
||||
print(module)
|
||||
|
||||
# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
|
||||
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
|
||||
# CHECK: scf.yield %[[IV]]
|
||||
|
||||
|
||||
@constructAndPrintInModule
|
||||
def testOpsAsArguments():
|
||||
index_type = IndexType.get()
|
||||
callee = builtin.FuncOp(
|
||||
"callee", ([], [index_type, index_type]), visibility="private")
|
||||
func = builtin.FuncOp("ops_as_arguments", ([], []))
|
||||
with InsertionPoint(func.add_entry_block()):
|
||||
lb = std.ConstantOp.create_index(0)
|
||||
ub = std.ConstantOp.create_index(42)
|
||||
step = std.ConstantOp.create_index(2)
|
||||
iter_args = std.CallOp(callee, [])
|
||||
loop = scf.ForOp(lb, ub, step, iter_args)
|
||||
with InsertionPoint(loop.body):
|
||||
scf.YieldOp(loop.inner_iter_args)
|
||||
std.ReturnOp([])
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOpsAsArguments
|
||||
# CHECK: func private @callee() -> (index, index)
|
||||
# CHECK: func @ops_as_arguments() {
|
||||
# CHECK: %[[LB:.*]] = constant 0
|
||||
# CHECK: %[[UB:.*]] = constant 42
|
||||
# CHECK: %[[STEP:.*]] = constant 2
|
||||
# CHECK: %[[ARGS:.*]]:2 = call @callee()
|
||||
# CHECK: scf.for %arg0 = %c0 to %c42 step %c2
|
||||
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
|
||||
# CHECK: scf.yield %{{.*}}, %{{.*}}
|
||||
# CHECK: return
|
||||
|
|
|
@ -28,7 +28,7 @@ constexpr const char *fileHeader = R"Py(
|
|||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
|
||||
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
try:
|
||||
|
@ -489,20 +489,25 @@ constexpr const char *initTemplate = R"Py(
|
|||
)Py";
|
||||
|
||||
/// Template for appending a single element to the operand/result list.
|
||||
/// {0} is either 'operand' or 'result';
|
||||
/// {1} is the field name.
|
||||
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
|
||||
/// {0} is the field name.
|
||||
constexpr const char *singleOperandAppendTemplate =
|
||||
"operands.append(_get_op_result_or_value({0}))";
|
||||
constexpr const char *singleResultAppendTemplate = "results.append({0})";
|
||||
|
||||
/// Template for appending an optional element to the operand/result list.
|
||||
/// {0} is either 'operand' or 'result';
|
||||
/// {1} is the field name.
|
||||
constexpr const char *optionalAppendTemplate =
|
||||
"if {1} is not None: {0}s.append({1})";
|
||||
/// {0} is the field name.
|
||||
constexpr const char *optionalAppendOperandTemplate =
|
||||
"if {0} is not None: operands.append(_get_op_result_or_value({0}))";
|
||||
constexpr const char *optionalAppendResultTemplate =
|
||||
"if {0} is not None: results.append({0})";
|
||||
|
||||
/// Template for appending a a list of elements to the operand/result list.
|
||||
/// {0} is either 'operand' or 'result';
|
||||
/// {1} is the field name.
|
||||
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
|
||||
/// Template for appending a list of elements to the operand/result list.
|
||||
/// {0} is the field name.
|
||||
constexpr const char *multiOperandAppendTemplate =
|
||||
"operands.extend(_get_op_results_or_values({0}))";
|
||||
constexpr const char *multiOperandAppendPackTemplate =
|
||||
"operands.append(_get_op_results_or_values({0}))";
|
||||
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
|
||||
|
||||
/// Template for setting an attribute in the operation builder.
|
||||
/// {0} is the attribute name;
|
||||
|
@ -625,43 +630,70 @@ static void populateBuilderLinesSuccessors(
|
|||
}
|
||||
|
||||
/// Populates `builderLines` with additional lines that are required in the
|
||||
/// builder. `kind` must be either "operand" or "result". `names` contains the
|
||||
/// names of init arguments that correspond to the elements.
|
||||
static void populateBuilderLines(
|
||||
const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines,
|
||||
llvm::function_ref<int(const Operator &)> getNumElements,
|
||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||
getElement) {
|
||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
|
||||
/// builder to set up op operands.
|
||||
static void
|
||||
populateBuilderLinesOperand(const Operator &op,
|
||||
llvm::ArrayRef<std::string> names,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
|
||||
|
||||
// For each element, find or generate a name.
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = op.getOperand(i);
|
||||
std::string name = names[i];
|
||||
|
||||
// Choose the formatting string based on the element kind.
|
||||
llvm::StringRef formatString;
|
||||
if (!element.isVariableLength()) {
|
||||
formatString = singleElementAppendTemplate;
|
||||
formatString = singleOperandAppendTemplate;
|
||||
} else if (element.isOptional()) {
|
||||
formatString = optionalAppendTemplate;
|
||||
formatString = optionalAppendOperandTemplate;
|
||||
} else {
|
||||
assert(element.isVariadic() && "unhandled element group type");
|
||||
// If emitting with sizedSegments, then we add the actual list typed
|
||||
// element using the singleElementAppendTemplate. Otherwise, we extend
|
||||
// the actual operands.
|
||||
// If emitting with sizedSegments, then we add the actual list-typed
|
||||
// element. Otherwise, we extend the actual operands.
|
||||
if (sizedSegments) {
|
||||
// Append the list as is.
|
||||
formatString = singleElementAppendTemplate;
|
||||
formatString = multiOperandAppendPackTemplate;
|
||||
} else {
|
||||
// Append the list elements.
|
||||
formatString = multiElementAppendTemplate;
|
||||
formatString = multiOperandAppendTemplate;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the lines.
|
||||
builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
|
||||
builderLines.push_back(llvm::formatv(formatString.data(), name));
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates `builderLines` with additional lines that are required in the
|
||||
/// builder to set up op results.
|
||||
static void
|
||||
populateBuilderLinesResult(const Operator &op,
|
||||
llvm::ArrayRef<std::string> names,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
|
||||
|
||||
// For each element, find or generate a name.
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = op.getResult(i);
|
||||
std::string name = names[i];
|
||||
|
||||
// Choose the formatting string based on the element kind.
|
||||
llvm::StringRef formatString;
|
||||
if (!element.isVariableLength()) {
|
||||
formatString = singleResultAppendTemplate;
|
||||
} else if (element.isOptional()) {
|
||||
formatString = optionalAppendResultTemplate;
|
||||
} else {
|
||||
assert(element.isVariadic() && "unhandled element group type");
|
||||
// If emitting with sizedSegments, then we add the actual list-typed
|
||||
// element. Otherwise, we extend the actual operands.
|
||||
if (sizedSegments) {
|
||||
formatString = singleResultAppendTemplate;
|
||||
} else {
|
||||
formatString = multiResultAppendTemplate;
|
||||
}
|
||||
}
|
||||
|
||||
builderLines.push_back(llvm::formatv(formatString.data(), name));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -680,12 +712,10 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
|
|||
op.getNumNativeAttributes() + op.getNumSuccessors());
|
||||
populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
|
||||
|
||||
populateBuilderLines(
|
||||
op, "result",
|
||||
llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
|
||||
builderLines, getNumResults, getResult);
|
||||
populateBuilderLines(op, "operand", operandArgNames, builderLines,
|
||||
getNumOperands, getOperand);
|
||||
populateBuilderLinesResult(
|
||||
op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
|
||||
builderLines);
|
||||
populateBuilderLinesOperand(op, operandArgNames, builderLines);
|
||||
populateBuilderLinesAttr(
|
||||
op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
|
||||
builderLines);
|
||||
|
|
Loading…
Reference in New Issue