[mlir][python] support taking ops instead of values in op constructors

Introduce support for accepting ops instead of values when constructing ops. A
single-result op can be used instead of a value, including in lists of values,
and any op can be used instead of a list of values. This is similar to, but
more powerful, than the C++ API that allows for implicitly casting an OpType to
Value if it is statically known to have a single result - the cast in Python is
based on the op dynamically having a single result, and also handles the
multi-result case. This allows to build IR in a more concise way:

    op = dialect.produce_multiple_results()
    other = dialect.produce_single_result()
    dialect.consume_multiple_results(other, op)

instead of having to access the results manually

    op = dialect.produce.multiple_results()
    other = dialect.produce_single_result()
    dialect.consume_multiple_results(other.result, op.operation.results)

The dispatch is implemented directly in Python and is triggered automatically
for autogenerated OpView subclasses. Extension OpView classes should use the
functions provided in ods_common.py if they want to implement this behavior.
An alternative could be to implement the dispatch in the C++ bindings code, but
it would require to forward opaque types through all Python functions down to a
binding call, which makes it hard to inspect them in Python, e.g., to obtain
the types of values.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D111306
This commit is contained in:
Alex Zinenko 2021-10-07 18:29:03 +02:00
parent cb879d00d8
commit b164f23c29
9 changed files with 269 additions and 115 deletions

View File

@ -10,6 +10,7 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
def isa(cls: Type, ty: Type):
try:
@ -26,11 +27,12 @@ class FillOp:
results = []
if isa(RankedTensorType, output.type):
results = [output.type]
op = self.build_generic(results=results,
operands=[value, output],
attributes=None,
loc=loc,
ip=ip)
op = self.build_generic(
results=results,
operands=[_get_op_result_or_value(o) for o in [value, output]],
attributes=None,
loc=loc,
ip=ip)
OpView.__init__(self, op)
linalgDialect = Context.current.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, self.operation)

View File

@ -5,11 +5,14 @@
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import Sequence as _Sequence, Union as _Union
__all__ = [
"equally_sized_accessor",
"extend_opview_class",
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
"segmented_accessor",
]
@ -118,3 +121,38 @@ def get_default_loc_context(location=None):
# Location.current raises ValueError if there is no current location.
return _cext.ir.Location.current.context
return location.context
def get_op_result_or_value(
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
) -> _cext.ir.Value:
"""Returns the given value or the single result of the given op.
This is useful to implement op constructors so that they can take other ops as
arguments instead of requiring the caller to extract results for every op.
Raises ValueError if provided with an op that doesn't have a single result.
"""
if isinstance(arg, _cext.ir.OpView):
return arg.operation.result
elif isinstance(arg, _cext.ir.Operation):
return arg.result
else:
assert isinstance(arg, _cext.ir.Value)
return arg
def get_op_results_or_values(
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
"""Returns the given sequence of values or the results of the given op.
This is useful to implement op constructors so that they can take other ops as
lists of arguments instead of requiring the caller to extract results for
every op.
"""
if isinstance(arg, _cext.ir.OpView):
return arg.operation.results
elif isinstance(arg, _cext.ir.Operation):
return arg.results
else:
return arg

View File

@ -7,8 +7,8 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Any, Sequence
from typing import Any, Optional, Sequence, Union
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
class ForOp:
"""Specialization for the SCF for op class."""
@ -17,7 +17,8 @@ class ForOp:
lower_bound,
upper_bound,
step,
iter_args: Sequence[Any] = [],
iter_args: Optional[Union[Operation, OpView,
Sequence[Value]]] = None,
*,
loc=None,
ip=None):
@ -26,14 +27,22 @@ class ForOp:
- `lower_bound` is the value to use as lower bound of the loop.
- `upper_bound` is the value to use as upper bound of the loop.
- `step` is the value to use as loop step.
- `iter_args` is a list of additional loop-carried arguments.
- `iter_args` is a list of additional loop-carried arguments or an operation
producing them as results.
"""
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
super().__init__(
self.build_generic(
regions=1,
results=results,
operands=[lower_bound, upper_bound, step] + list(iter_args),
operands=[
_get_op_result_or_value(o)
for o in [lower_bound, upper_bound, step]
] + list(iter_args),
loc=loc,
ip=ip))
self.regions[0].blocks.append(IndexType.get(), *results)

View File

@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Dict, List
from typing import Dict, List, Sequence, Union
from contextlib import contextmanager
import functools
@ -10,12 +10,15 @@ import inspect
import threading
from ..... import ir
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from .comprehension import *
from .config import *
from .emitter import *
_CONTEXT = threading.local()
StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
@contextmanager
def bind_op_def(model: LinalgOpDef):
@ -37,6 +40,15 @@ def current_op_def() -> LinalgOpDef:
"but none is set. Did you mean to call this in an op definition?")
def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
if isinstance(outs, (ir.Operation, ir.OpView)):
return _get_op_results_or_values(outs)
elif isinstance(outs, ir.OpResultList):
return outs
return [_get_op_result_or_value(o) for o in outs]
class DefinedOpCallable:
"""Callable that wraps any defined op function."""
@ -44,7 +56,8 @@ class DefinedOpCallable:
self.op_name = op_name
self.model = model
def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
outs: StructuredOpOuts, **kwargs):
"""Emits the corresponding op definition as IR.
Most arguments are passed through to the underlying emitter. The following
@ -73,17 +86,19 @@ class DefinedOpCallable:
emit_generic or not ctx.is_registered_operation(fully_qualified_name))
op_config = op_configs[0]
out_values = _prepare_structured_op_outs(outs)
in_values = [_get_op_result_or_value(i) for i in ins]
if op_config.structured_op:
if emit_generic:
return emit_generic_structured_op(
op_config.structured_op, *ins, outs=outs, **kwargs)
op_config.structured_op, *in_values, outs=out_values, **kwargs)
else:
return emit_named_structured_op(
op_config.structured_op,
self.op_name,
self.model.metadata.cpp_class_name,
*ins,
outs=outs,
*in_values,
outs=out_values,
**kwargs)
raise NotImplementedError(

View File

@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Dict, Sequence
from typing import Dict, List, Sequence, Tuple, Union
from .....ir import *
from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
@ -10,6 +10,7 @@ from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
from .... import linalg
from .... import std
from .... import math
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from .scalar_expr import *
from .config import *
@ -18,8 +19,10 @@ import numpy as np
__all__ = [
"emit_generic_structured_op",
"emit_named_structured_op",
"ValueList",
]
ValueList = Union[Sequence[Value], OpResultList]
def isa(cls: Type, ty: Type):
try:
@ -30,17 +33,18 @@ def isa(cls: Type, ty: Type):
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value, outs: Sequence[Value],
*ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs = op_config.ordered_operands
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
# Verify outs is a sequence.
if not isinstance(outs, Sequence):
raise ValueError(f"Expected named argument outs to have type Sequence "
f"but got {type(outs)}")
# Verify outs is a sequence or a list of results.
if not isinstance(outs, (Sequence, OpResultList)):
raise ValueError(
f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
)
# Arity validation.
if len(ins) != len(in_arg_defs):
@ -122,7 +126,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
outs: Sequence[Value], **attrs: Sequence[int]):
outs: ValueList, **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@ -153,8 +157,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
op_class_name: str, *ins: Value,
outs: Sequence[Value], **attrs: Sequence[int]):
op_class_name: str, *ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@ -355,11 +359,11 @@ class _BodyBuilder:
return std.MinUIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
in_arg_defs: Sequence[OperandDefConfig],
ins: Sequence[Value],
out_arg_defs: Sequence[OperandDefConfig],
outs: Sequence[Value]):
def _infer_structured_outs(
op_config: LinalgStructuredOpConfig,
in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
out_arg_defs: Sequence[OperandDefConfig],
outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
"""Infers implicit outs and output types.
Respects existing contents of outs if not empty.

View File

@ -24,9 +24,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(variadic1)
// CHECK: operands.append(non_variadic)
// CHECK: if variadic2 is not None: operands.append(variadic2)
// CHECK: operands.append(_get_op_results_or_values(variadic1))
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
// CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@ -150,8 +150,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(_gen_arg_2)
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
@ -197,9 +197,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: results.append(i32)
// CHECK: results.append(_gen_res_1)
// CHECK: results.append(i64)
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32)
// CHECK: operands.append(_gen_arg_2)
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
// CHECK: operands.append(_get_op_result_or_value(f32))
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@ -230,8 +230,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(non_variadic)
// CHECK: operands.extend(variadic)
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
// CHECK: operands.extend(_get_op_results_or_values(variadic))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@ -285,7 +285,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(in_)
// CHECK: operands.append(_get_op_result_or_value(in_))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@ -353,8 +353,8 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: attributes = {}
// CHECK: results.append(i64)
// CHECK: results.append(f64)
// CHECK: operands.append(i32)
// CHECK: operands.append(f32)
// CHECK: operands.append(_get_op_result_or_value(i32))
// CHECK: operands.append(_get_op_result_or_value(f32))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,

View File

@ -185,3 +185,30 @@ def testNamedStructuredAsGenericOp():
return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
print(module)
# CHECK-LABEL: TEST: testOpResultFromOtherOp
@run
def testOpResultFromOtherOp():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
f32))
def pass_an_op_directly(arg0, arg1):
one = std.ConstantOp(F32Type.get(), 1.0)
# CHECK: %[[LHS:.*]] = linalg.fill
lhs = linalg.FillOp(arg0, one)
# CHECK: %[[RHS:.*]] = linalg.fill
rhs = linalg.FillOp(arg1, one)
# CHECK: %[[INIT:.*]] = linalg.init_tensor
init = linalg.InitTensorOp([4, 8], f32)
# CHECK: linalg.matmul
# CHECK: ins(%[[LHS]], %[[RHS]]
# CHECK: outs(%[[INIT]]
return linalg.matmul(lhs, rhs, outs=init)
print(module)

View File

@ -2,53 +2,82 @@
from mlir.ir import *
from mlir.dialects import scf
from mlir.dialects import std
from mlir.dialects import builtin
def run(f):
def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
f()
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
return f
# CHECK-LABEL: TEST: testSimpleLoop
@run
@constructAndPrintInModule
def testSimpleLoop():
with Context(), Location.unknown():
module = Module.create()
index_type = IndexType.get()
with InsertionPoint(module.body):
index_type = IndexType.get()
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
def simple_loop(lb, ub, step):
loop = scf.ForOp(lb, ub, step, [lb, lb])
with InsertionPoint(loop.body):
scf.YieldOp(loop.inner_iter_args)
return
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
def simple_loop(lb, ub, step):
loop = scf.ForOp(lb, ub, step, [lb, lb])
with InsertionPoint(loop.body):
scf.YieldOp(loop.inner_iter_args)
return
# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]
print(module)
# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]
# CHECK-LABEL: TEST: testInductionVar
@run
@constructAndPrintInModule
def testInductionVar():
with Context(), Location.unknown():
module = Module.create()
index_type = IndexType.get()
with InsertionPoint(module.body):
index_type = IndexType.get()
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
def induction_var(lb, ub, step):
loop = scf.ForOp(lb, ub, step, [lb])
with InsertionPoint(loop.body):
scf.YieldOp([loop.induction_variable])
return
@builtin.FuncOp.from_py_func(index_type, index_type, index_type)
def induction_var(lb, ub, step):
loop = scf.ForOp(lb, ub, step, [lb])
with InsertionPoint(loop.body):
scf.YieldOp([loop.induction_variable])
return
# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]
print(module)
# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]
@constructAndPrintInModule
def testOpsAsArguments():
index_type = IndexType.get()
callee = builtin.FuncOp(
"callee", ([], [index_type, index_type]), visibility="private")
func = builtin.FuncOp("ops_as_arguments", ([], []))
with InsertionPoint(func.add_entry_block()):
lb = std.ConstantOp.create_index(0)
ub = std.ConstantOp.create_index(42)
step = std.ConstantOp.create_index(2)
iter_args = std.CallOp(callee, [])
loop = scf.ForOp(lb, ub, step, iter_args)
with InsertionPoint(loop.body):
scf.YieldOp(loop.inner_iter_args)
std.ReturnOp([])
# CHECK-LABEL: TEST: testOpsAsArguments
# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK: %[[LB:.*]] = constant 0
# CHECK: %[[UB:.*]] = constant 42
# CHECK: %[[STEP:.*]] = constant 2
# CHECK: %[[ARGS:.*]]:2 = call @callee()
# CHECK: scf.for %arg0 = %c0 to %c42 step %c2
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}
# CHECK: return

View File

@ -28,7 +28,7 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir
try:
@ -489,20 +489,25 @@ constexpr const char *initTemplate = R"Py(
)Py";
/// Template for appending a single element to the operand/result list.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
/// {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
"operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Template for appending an optional element to the operand/result list.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
constexpr const char *optionalAppendTemplate =
"if {1} is not None: {0}s.append({1})";
/// {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
"if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendResultTemplate =
"if {0} is not None: results.append({0})";
/// Template for appending a a list of elements to the operand/result list.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
/// Template for appending a list of elements to the operand/result list.
/// {0} is the field name.
constexpr const char *multiOperandAppendTemplate =
"operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
"operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
@ -625,43 +630,70 @@ static void populateBuilderLinesSuccessors(
}
/// Populates `builderLines` with additional lines that are required in the
/// builder. `kind` must be either "operand" or "result". `names` contains the
/// names of init arguments that correspond to the elements.
static void populateBuilderLines(
const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
/// builder to set up op operands.
static void
populateBuilderLinesOperand(const Operator &op,
llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
// For each element, find or generate a name.
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
const NamedTypeConstraint &element = op.getOperand(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleElementAppendTemplate;
formatString = singleOperandAppendTemplate;
} else if (element.isOptional()) {
formatString = optionalAppendTemplate;
formatString = optionalAppendOperandTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list typed
// element using the singleElementAppendTemplate. Otherwise, we extend
// the actual operands.
// If emitting with sizedSegments, then we add the actual list-typed
// element. Otherwise, we extend the actual operands.
if (sizedSegments) {
// Append the list as is.
formatString = singleElementAppendTemplate;
formatString = multiOperandAppendPackTemplate;
} else {
// Append the list elements.
formatString = multiElementAppendTemplate;
formatString = multiOperandAppendTemplate;
}
}
// Add the lines.
builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
builderLines.push_back(llvm::formatv(formatString.data(), name));
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
static void
populateBuilderLinesResult(const Operator &op,
llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
// For each element, find or generate a name.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
const NamedTypeConstraint &element = op.getResult(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleResultAppendTemplate;
} else if (element.isOptional()) {
formatString = optionalAppendResultTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list-typed
// element. Otherwise, we extend the actual operands.
if (sizedSegments) {
formatString = singleResultAppendTemplate;
} else {
formatString = multiResultAppendTemplate;
}
}
builderLines.push_back(llvm::formatv(formatString.data(), name));
}
}
@ -680,12 +712,10 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
op.getNumNativeAttributes() + op.getNumSuccessors());
populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
populateBuilderLines(
op, "result",
llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
builderLines, getNumResults, getResult);
populateBuilderLines(op, "operand", operandArgNames, builderLines,
getNumOperands, getOperand);
populateBuilderLinesResult(
op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
builderLines);
populateBuilderLinesOperand(op, operandArgNames, builderLines);
populateBuilderLinesAttr(
op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
builderLines);