Reapply "[mlir][PDL] Add support for native constraints with results (#82760)"
with a small stack-use-after-scope fix in getConstraintPredicates() This reverts commit c80e6edba4a9593f0587e27fa0ac825ebe174afd.
This commit is contained in:
parent
da591d390e
commit
8ec28af8ea
@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
|
||||
let description = [{
|
||||
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
|
||||
has been registered externally with the consumer of PDL, to a given set of
|
||||
entities.
|
||||
entities and optionally return a number of values.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
|
||||
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
|
||||
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
|
||||
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins StrAttr:$name,
|
||||
Variadic<PDL_AnyType>:$args,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
|
||||
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
|
||||
let results = (outs Variadic<PDL_AnyType>:$results);
|
||||
let assemblyFormat = [{
|
||||
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
|
||||
let description = [{
|
||||
`pdl_interp.apply_constraint` operations apply a generic constraint, that
|
||||
has been registered with the interpreter, with a given set of positional
|
||||
values. On success, this operation branches to the true destination,
|
||||
values.
|
||||
The constraint function may return any number of results.
|
||||
On success, this operation branches to the true destination,
|
||||
otherwise the false destination is taken. This behavior can be reversed
|
||||
by setting the attribute `isNegated` to true.
|
||||
|
||||
@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
|
||||
let arguments = (ins StrAttr:$name,
|
||||
Variadic<PDL_AnyType>:$args,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
|
||||
let results = (outs Variadic<PDL_AnyType>:$results);
|
||||
let assemblyFormat = [{
|
||||
$name `(` $args `:` type($args) `)` attr-dict `->` successors
|
||||
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
|
||||
`->` successors
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -318,8 +318,9 @@ protected:
|
||||
/// A generic PDL pattern constraint function. This function applies a
|
||||
/// constraint to a given set of opaque PDLValue entities. Returns success if
|
||||
/// the constraint successfully held, failure otherwise.
|
||||
using PDLConstraintFunction =
|
||||
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
|
||||
using PDLConstraintFunction = std::function<LogicalResult(
|
||||
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
||||
|
||||
/// A native PDL rewrite function. This function performs a rewrite on the
|
||||
/// given set of values. Any results from this rewrite that should be passed
|
||||
/// back to PDL should be added to the provided result list. This method is only
|
||||
@ -726,7 +727,7 @@ std::enable_if_t<
|
||||
PDLConstraintFunction>
|
||||
buildConstraintFn(ConstraintFnT &&constraintFn) {
|
||||
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
|
||||
PatternRewriter &rewriter,
|
||||
PatternRewriter &rewriter, PDLResultList &,
|
||||
ArrayRef<PDLValue> values) -> LogicalResult {
|
||||
auto argIndices = std::make_index_sequence<
|
||||
llvm::function_traits<ConstraintFnT>::num_args - 1>();
|
||||
@ -842,10 +843,13 @@ public:
|
||||
/// Register a constraint function with PDL. A constraint function may be
|
||||
/// specified in one of two ways:
|
||||
///
|
||||
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
|
||||
/// * `LogicalResult (PatternRewriter &,
|
||||
/// PDLResultList &,
|
||||
/// ArrayRef<PDLValue>)`
|
||||
///
|
||||
/// In this overload the arguments of the constraint function are passed via
|
||||
/// the low-level PDLValue form.
|
||||
/// the low-level PDLValue form, and the results are manually appended to
|
||||
/// the given result list.
|
||||
///
|
||||
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
|
||||
///
|
||||
@ -960,8 +964,8 @@ public:
|
||||
}
|
||||
};
|
||||
class PDLResultList {};
|
||||
using PDLConstraintFunction =
|
||||
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
|
||||
using PDLConstraintFunction = std::function<LogicalResult(
|
||||
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
||||
using PDLRewriteFunction = std::function<LogicalResult(
|
||||
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
||||
|
||||
|
@ -50,7 +50,8 @@ private:
|
||||
|
||||
/// Generate interpreter operations for the tree rooted at the given matcher
|
||||
/// node, in the specified region.
|
||||
Block *generateMatcher(MatcherNode &node, Region ®ion);
|
||||
Block *generateMatcher(MatcherNode &node, Region ®ion,
|
||||
Block *block = nullptr);
|
||||
|
||||
/// Get or create an access to the provided positional value in the current
|
||||
/// block. This operation may mutate the provided block pointer if nested
|
||||
@ -148,6 +149,10 @@ private:
|
||||
/// A mapping between pattern operations and the corresponding configuration
|
||||
/// set.
|
||||
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
|
||||
|
||||
/// A mapping from a constraint question to the ApplyConstraintOp
|
||||
/// that implements it.
|
||||
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
|
||||
firstMatcherBlock->erase();
|
||||
}
|
||||
|
||||
Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) {
|
||||
Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion,
|
||||
Block *block) {
|
||||
// Push a new scope for the values used by this matcher.
|
||||
Block *block = ®ion.emplaceBlock();
|
||||
if (!block)
|
||||
block = ®ion.emplaceBlock();
|
||||
ValueMapScope scope(values);
|
||||
|
||||
// If this is the return node, simply insert the corresponding interpreter
|
||||
@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
|
||||
loc, cast<ArrayAttr>(rawTypeAttr));
|
||||
break;
|
||||
}
|
||||
case Predicates::ConstraintResultPos: {
|
||||
// Due to the order of traversal, the ApplyConstraintOp has already been
|
||||
// created and we can find it in constraintOpMap.
|
||||
auto *constrResPos = cast<ConstraintPosition>(pos);
|
||||
auto i = constraintOpMap.find(constrResPos->getQuestion());
|
||||
assert(i != constraintOpMap.end());
|
||||
value = i->second->getResult(constrResPos->getIndex());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("Generating unknown Position getter");
|
||||
break;
|
||||
@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
||||
args.push_back(getValueAt(currentBlock, position));
|
||||
}
|
||||
|
||||
// Generate the matcher in the current (potentially nested) region
|
||||
// and get the failure successor.
|
||||
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
|
||||
// Generate a new block as success successor and get the failure successor.
|
||||
Block *success = ®ion->emplaceBlock();
|
||||
Block *failure = failureBlockStack.back();
|
||||
|
||||
// Finally, create the predicate.
|
||||
// Create the predicate.
|
||||
builder.setInsertionPointToEnd(currentBlock);
|
||||
Predicates::Kind kind = question->getKind();
|
||||
switch (kind) {
|
||||
@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
||||
}
|
||||
case Predicates::ConstraintQuestion: {
|
||||
auto *cstQuestion = cast<ConstraintQuestion>(question);
|
||||
builder.create<pdl_interp::ApplyConstraintOp>(
|
||||
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
|
||||
failure);
|
||||
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
|
||||
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
|
||||
cstQuestion->getIsNegated(), success, failure);
|
||||
|
||||
constraintOpMap.insert({cstQuestion, applyConstraintOp});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("Generating unknown Predicate operation");
|
||||
}
|
||||
|
||||
// Generate the matcher in the current (potentially nested) region.
|
||||
// This might use the results of the current predicate.
|
||||
generateMatcher(*boolNode->getSuccessNode(), *region, success);
|
||||
}
|
||||
|
||||
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
|
||||
|
@ -47,6 +47,7 @@ enum Kind : unsigned {
|
||||
OperandPos,
|
||||
OperandGroupPos,
|
||||
AttributePos,
|
||||
ConstraintResultPos,
|
||||
ResultPos,
|
||||
ResultGroupPos,
|
||||
TypePos,
|
||||
@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
|
||||
bool isOperandDefiningOp() const;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstraintPosition
|
||||
|
||||
struct ConstraintQuestion;
|
||||
|
||||
/// A position describing the result of a native constraint. It saves the
|
||||
/// corresponding ConstraintQuestion and result index to enable referring
|
||||
/// back to them
|
||||
struct ConstraintPosition
|
||||
: public PredicateBase<ConstraintPosition, Position,
|
||||
std::pair<ConstraintQuestion *, unsigned>,
|
||||
Predicates::ConstraintResultPos> {
|
||||
using PredicateBase::PredicateBase;
|
||||
|
||||
/// Returns the ConstraintQuestion to enable keeping track of the native
|
||||
/// constraint this position stems from.
|
||||
ConstraintQuestion *getQuestion() const { return key.first; }
|
||||
|
||||
// Returns the result index of this position
|
||||
unsigned getIndex() const { return key.second; }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResultPosition
|
||||
|
||||
@ -447,11 +470,13 @@ struct AttributeQuestion
|
||||
: public PredicateBase<AttributeQuestion, Qualifier, void,
|
||||
Predicates::AttributeQuestion> {};
|
||||
|
||||
/// Apply a parameterized constraint to multiple position values.
|
||||
/// Apply a parameterized constraint to multiple position values and possibly
|
||||
/// produce results.
|
||||
struct ConstraintQuestion
|
||||
: public PredicateBase<ConstraintQuestion, Qualifier,
|
||||
std::tuple<StringRef, ArrayRef<Position *>, bool>,
|
||||
Predicates::ConstraintQuestion> {
|
||||
: public PredicateBase<
|
||||
ConstraintQuestion, Qualifier,
|
||||
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
|
||||
Predicates::ConstraintQuestion> {
|
||||
using Base::Base;
|
||||
|
||||
/// Return the name of the constraint.
|
||||
@ -460,15 +485,19 @@ struct ConstraintQuestion
|
||||
/// Return the arguments of the constraint.
|
||||
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
|
||||
|
||||
/// Return the result types of the constraint.
|
||||
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
|
||||
|
||||
/// Return the negation status of the constraint.
|
||||
bool getIsNegated() const { return std::get<2>(key); }
|
||||
bool getIsNegated() const { return std::get<3>(key); }
|
||||
|
||||
/// Construct an instance with the given storage allocator.
|
||||
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
|
||||
KeyTy key) {
|
||||
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
|
||||
alloc.copyInto(std::get<1>(key)),
|
||||
std::get<2>(key)});
|
||||
alloc.copyInto(std::get<2>(key)),
|
||||
std::get<3>(key)});
|
||||
}
|
||||
|
||||
/// Returns a hash suitable for the given keytype.
|
||||
@ -526,6 +555,7 @@ public:
|
||||
// Register the types of Positions with the uniquer.
|
||||
registerParametricStorageType<AttributePosition>();
|
||||
registerParametricStorageType<AttributeLiteralPosition>();
|
||||
registerParametricStorageType<ConstraintPosition>();
|
||||
registerParametricStorageType<ForEachPosition>();
|
||||
registerParametricStorageType<OperandPosition>();
|
||||
registerParametricStorageType<OperandGroupPosition>();
|
||||
@ -588,6 +618,12 @@ public:
|
||||
return OperationPosition::get(uniquer, p);
|
||||
}
|
||||
|
||||
// Returns a position for a new value created by a constraint.
|
||||
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
|
||||
unsigned index) {
|
||||
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
|
||||
}
|
||||
|
||||
/// Returns an attribute position for an attribute of the given operation.
|
||||
Position *getAttribute(OperationPosition *p, StringRef name) {
|
||||
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
|
||||
@ -673,11 +709,11 @@ public:
|
||||
}
|
||||
|
||||
/// Create a predicate that applies a generic constraint.
|
||||
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
|
||||
bool isNegated) {
|
||||
return {
|
||||
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
|
||||
TrueAnswer::get(uniquer)};
|
||||
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
|
||||
ArrayRef<Type> resultTypes, bool isNegated) {
|
||||
return {ConstraintQuestion::get(
|
||||
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
|
||||
TrueAnswer::get(uniquer)};
|
||||
}
|
||||
|
||||
/// Create a predicate comparing a value with null.
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <queue>
|
||||
@ -49,14 +50,15 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
DenseMap<Value, Position *> &inputs,
|
||||
AttributePosition *pos) {
|
||||
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
|
||||
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
|
||||
predList.emplace_back(pos, builder.getIsNotNull());
|
||||
|
||||
// If the attribute has a type or value, add a constraint.
|
||||
if (Value type = attr.getValueType())
|
||||
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
|
||||
else if (Attribute value = attr.getValueAttr())
|
||||
predList.emplace_back(pos, builder.getAttributeConstraint(value));
|
||||
if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
|
||||
// If the attribute has a type or value, add a constraint.
|
||||
if (Value type = attr.getValueType())
|
||||
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
|
||||
else if (Attribute value = attr.getValueAttr())
|
||||
predList.emplace_back(pos, builder.getAttributeConstraint(value));
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect all of the predicates for the given operand position.
|
||||
@ -272,8 +274,27 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
|
||||
// Push the constraint to the furthest position.
|
||||
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
|
||||
comparePosDepth);
|
||||
PredicateBuilder::Predicate pred =
|
||||
builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
|
||||
ResultRange results = op.getResults();
|
||||
PredicateBuilder::Predicate pred = builder.getConstraint(
|
||||
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
|
||||
op.getIsNegated());
|
||||
|
||||
// For each result register a position so it can be used later
|
||||
for (auto [i, result] : llvm::enumerate(results)) {
|
||||
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
|
||||
ConstraintPosition *pos = builder.getConstraintPosition(q, i);
|
||||
auto [it, inserted] = inputs.try_emplace(result, pos);
|
||||
// If this is an input value that has been visited in the tree, add a
|
||||
// constraint to ensure that both instances refer to the same value.
|
||||
if (!inserted) {
|
||||
Position *first = pos;
|
||||
Position *second = it->second;
|
||||
if (comparePosDepth(second, first))
|
||||
std::tie(second, first) = std::make_pair(first, second);
|
||||
|
||||
predList.emplace_back(second, builder.getEqualTo(first));
|
||||
}
|
||||
}
|
||||
predList.emplace_back(pos, pred);
|
||||
}
|
||||
|
||||
@ -875,6 +896,49 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
|
||||
*root = std::make_unique<ExitNode>();
|
||||
}
|
||||
|
||||
/// Sorts the range begin/end with the partial order given by cmp.
|
||||
template <typename Iterator, typename Compare>
|
||||
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
|
||||
while (begin != end) {
|
||||
// Cannot compute sortBeforeOthers in the predicate of stable_partition
|
||||
// because stable_partition will not keep the [begin, end) range intact
|
||||
// while it runs.
|
||||
llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
|
||||
for (auto i = begin; i != end; ++i) {
|
||||
if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
|
||||
sortBeforeOthers.insert(*i);
|
||||
}
|
||||
|
||||
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
|
||||
return sortBeforeOthers.contains(a);
|
||||
});
|
||||
assert(next != begin && "not a partial ordering");
|
||||
begin = next;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if 'b' depends on a result of 'a'.
|
||||
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
|
||||
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
|
||||
if (!cqa)
|
||||
return false;
|
||||
|
||||
auto positionDependsOnA = [&](Position *p) {
|
||||
auto *cp = dyn_cast<ConstraintPosition>(p);
|
||||
return cp && cp->getQuestion() == cqa;
|
||||
};
|
||||
|
||||
if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
|
||||
// Does any argument of b use a?
|
||||
return llvm::any_of(cqb->getArgs(), positionDependsOnA);
|
||||
}
|
||||
if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
|
||||
return positionDependsOnA(b->position) ||
|
||||
positionDependsOnA(equalTo->getValue());
|
||||
}
|
||||
return positionDependsOnA(b->position);
|
||||
}
|
||||
|
||||
/// Given a module containing PDL pattern operations, generate a matcher tree
|
||||
/// using the patterns within the given module and return the root matcher node.
|
||||
std::unique_ptr<MatcherNode>
|
||||
@ -955,6 +1019,10 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
|
||||
return *lhs < *rhs;
|
||||
});
|
||||
|
||||
// Mostly keep the now established order, but also ensure that
|
||||
// ConstraintQuestions come after the results they use.
|
||||
stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
|
||||
|
||||
// Build the matchers for each of the pattern predicate lists.
|
||||
std::unique_ptr<MatcherNode> root;
|
||||
for (OrderedPredicateList &list : lists)
|
||||
|
@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
|
||||
LogicalResult ApplyNativeConstraintOp::verify() {
|
||||
if (getNumOperands() == 0)
|
||||
return emitOpError("expected at least one argument");
|
||||
if (llvm::any_of(getResults(), [](OpResult result) {
|
||||
return isa<OperationType>(result.getType());
|
||||
})) {
|
||||
return emitOpError(
|
||||
"returning an operation from a constraint is not supported");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -769,11 +769,25 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
|
||||
|
||||
void Generator::generate(pdl_interp::ApplyConstraintOp op,
|
||||
ByteCodeWriter &writer) {
|
||||
assert(constraintToMemIndex.count(op.getName()) &&
|
||||
"expected index for constraint function");
|
||||
// Constraints that should return a value have to be registered as rewrites.
|
||||
// If a constraint and a rewrite of similar name are registered the
|
||||
// constraint takes precedence
|
||||
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
|
||||
writer.appendPDLValueList(op.getArgs());
|
||||
writer.append(ByteCodeField(op.getIsNegated()));
|
||||
ResultRange results = op.getResults();
|
||||
writer.append(ByteCodeField(results.size()));
|
||||
for (Value result : results) {
|
||||
// We record the expected kind of the result, so that we can provide extra
|
||||
// verification of the native rewrite function and handle the failure case
|
||||
// of constraints accordingly.
|
||||
writer.appendPDLValueKind(result);
|
||||
|
||||
// Range results also need to append the range storage index.
|
||||
if (isa<pdl::RangeType>(result.getType()))
|
||||
writer.append(getRangeStorageIndex(result));
|
||||
writer.append(result);
|
||||
}
|
||||
writer.append(op.getSuccessors());
|
||||
}
|
||||
void Generator::generate(pdl_interp::ApplyRewriteOp op,
|
||||
@ -786,11 +800,9 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
|
||||
ResultRange results = op.getResults();
|
||||
writer.append(ByteCodeField(results.size()));
|
||||
for (Value result : results) {
|
||||
// In debug mode we also record the expected kind of the result, so that we
|
||||
// We record the expected kind of the result, so that we
|
||||
// can provide extra verification of the native rewrite function.
|
||||
#ifndef NDEBUG
|
||||
writer.appendPDLValueKind(result);
|
||||
#endif
|
||||
|
||||
// Range results also need to append the range storage index.
|
||||
if (isa<pdl::RangeType>(result.getType()))
|
||||
@ -1076,6 +1088,28 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
|
||||
// ByteCode Execution
|
||||
|
||||
namespace {
|
||||
/// This class is an instantiation of the PDLResultList that provides access to
|
||||
/// the returned results. This API is not on `PDLResultList` to avoid
|
||||
/// overexposing access to information specific solely to the ByteCode.
|
||||
class ByteCodeRewriteResultList : public PDLResultList {
|
||||
public:
|
||||
ByteCodeRewriteResultList(unsigned maxNumResults)
|
||||
: PDLResultList(maxNumResults) {}
|
||||
|
||||
/// Return the list of PDL results.
|
||||
MutableArrayRef<PDLValue> getResults() { return results; }
|
||||
|
||||
/// Return the type ranges allocated by this list.
|
||||
MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
|
||||
return allocatedTypeRanges;
|
||||
}
|
||||
|
||||
/// Return the value ranges allocated by this list.
|
||||
MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
|
||||
return allocatedValueRanges;
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides support for executing a bytecode stream.
|
||||
class ByteCodeExecutor {
|
||||
public:
|
||||
@ -1152,6 +1186,9 @@ private:
|
||||
void executeSwitchResultCount();
|
||||
void executeSwitchType();
|
||||
void executeSwitchTypes();
|
||||
void processNativeFunResults(ByteCodeRewriteResultList &results,
|
||||
unsigned numResults,
|
||||
LogicalResult &rewriteResult);
|
||||
|
||||
/// Pushes a code iterator to the stack.
|
||||
void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
|
||||
@ -1225,6 +1262,8 @@ private:
|
||||
return T::getFromOpaquePointer(pointer);
|
||||
}
|
||||
|
||||
void skip(size_t skipN) { curCodeIt += skipN; }
|
||||
|
||||
/// Jump to a specific successor based on a predicate value.
|
||||
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
|
||||
/// Jump to a specific successor based on a destination index.
|
||||
@ -1381,33 +1420,11 @@ private:
|
||||
ArrayRef<PDLConstraintFunction> constraintFunctions;
|
||||
ArrayRef<PDLRewriteFunction> rewriteFunctions;
|
||||
};
|
||||
|
||||
/// This class is an instantiation of the PDLResultList that provides access to
|
||||
/// the returned results. This API is not on `PDLResultList` to avoid
|
||||
/// overexposing access to information specific solely to the ByteCode.
|
||||
class ByteCodeRewriteResultList : public PDLResultList {
|
||||
public:
|
||||
ByteCodeRewriteResultList(unsigned maxNumResults)
|
||||
: PDLResultList(maxNumResults) {}
|
||||
|
||||
/// Return the list of PDL results.
|
||||
MutableArrayRef<PDLValue> getResults() { return results; }
|
||||
|
||||
/// Return the type ranges allocated by this list.
|
||||
MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
|
||||
return allocatedTypeRanges;
|
||||
}
|
||||
|
||||
/// Return the value ranges allocated by this list.
|
||||
MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
|
||||
return allocatedValueRanges;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
|
||||
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
|
||||
ByteCodeField fun_idx = read();
|
||||
SmallVector<PDLValue, 16> args;
|
||||
readList<PDLValue>(args);
|
||||
|
||||
@ -1422,8 +1439,29 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
|
||||
llvm::dbgs() << " * isNegated: " << isNegated << "\n";
|
||||
llvm::interleaveComma(args, llvm::dbgs());
|
||||
});
|
||||
// Invoke the constraint and jump to the proper destination.
|
||||
selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
|
||||
|
||||
ByteCodeField numResults = read();
|
||||
const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
|
||||
ByteCodeRewriteResultList results(numResults);
|
||||
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
|
||||
ArrayRef<PDLValue> constraintResults = results.getResults();
|
||||
LLVM_DEBUG({
|
||||
if (succeeded(rewriteResult)) {
|
||||
llvm::dbgs() << " * Constraint succeeded\n";
|
||||
llvm::dbgs() << " * Results: ";
|
||||
llvm::interleaveComma(constraintResults, llvm::dbgs());
|
||||
llvm::dbgs() << "\n";
|
||||
} else {
|
||||
llvm::dbgs() << " * Constraint failed\n";
|
||||
}
|
||||
});
|
||||
assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
|
||||
"native PDL rewrite function succeeded but returned "
|
||||
"unexpected number of results");
|
||||
processNativeFunResults(results, numResults, rewriteResult);
|
||||
|
||||
// Depending on the constraint jump to the proper destination.
|
||||
selectJump(isNegated != succeeded(rewriteResult));
|
||||
}
|
||||
|
||||
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
||||
@ -1445,16 +1483,39 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
||||
assert(results.getResults().size() == numResults &&
|
||||
"native PDL rewrite function returned unexpected number of results");
|
||||
|
||||
// Store the results in the bytecode memory.
|
||||
for (PDLValue &result : results.getResults()) {
|
||||
processNativeFunResults(results, numResults, rewriteResult);
|
||||
|
||||
if (failed(rewriteResult)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " - Failed");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ByteCodeExecutor::processNativeFunResults(
|
||||
ByteCodeRewriteResultList &results, unsigned numResults,
|
||||
LogicalResult &rewriteResult) {
|
||||
// Store the results in the bytecode memory or handle missing results on
|
||||
// failure.
|
||||
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
|
||||
PDLValue::Kind resultKind = read<PDLValue::Kind>();
|
||||
|
||||
// Skip the according number of values on the buffer on failure and exit
|
||||
// early as there are no results to process.
|
||||
if (failed(rewriteResult)) {
|
||||
if (resultKind == PDLValue::Kind::TypeRange ||
|
||||
resultKind == PDLValue::Kind::ValueRange) {
|
||||
skip(2);
|
||||
} else {
|
||||
skip(1);
|
||||
}
|
||||
return;
|
||||
}
|
||||
PDLValue result = results.getResults()[resultIdx];
|
||||
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
|
||||
|
||||
// In debug mode we also verify the expected kind of the result.
|
||||
#ifndef NDEBUG
|
||||
assert(result.getKind() == read<PDLValue::Kind>() &&
|
||||
"native PDL rewrite function returned an unexpected type of result");
|
||||
#endif
|
||||
|
||||
assert(result.getKind() == resultKind &&
|
||||
"native PDL rewrite function returned an unexpected type of "
|
||||
"result");
|
||||
// If the result is a range, we need to copy it over to the bytecodes
|
||||
// range memory.
|
||||
if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
|
||||
@ -1476,13 +1537,6 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
||||
allocatedTypeRangeMemory.push_back(std::move(it));
|
||||
for (auto &it : results.getAllocatedValueRanges())
|
||||
allocatedValueRangeMemory.push_back(std::move(it));
|
||||
|
||||
// Process the result of the rewrite.
|
||||
if (failed(rewriteResult)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " - Failed");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ByteCodeExecutor::executeAreEqual() {
|
||||
|
@ -1362,12 +1362,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
|
||||
if (failed(parseToken(Token::semicolon,
|
||||
"expected `;` after native declaration")))
|
||||
return failure();
|
||||
// TODO: PDL should be able to support constraint results in certain
|
||||
// situations, we should revise this.
|
||||
if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
|
||||
return emitError(
|
||||
"native Constraints currently do not support returning results");
|
||||
}
|
||||
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
|
||||
}
|
||||
|
||||
|
@ -79,6 +79,57 @@ module @constraints {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module @constraint_with_result
|
||||
module @constraint_with_result {
|
||||
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
|
||||
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
|
||||
pdl.pattern : benefit(1) {
|
||||
%root = operation
|
||||
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module @constraint_with_unused_result
|
||||
module @constraint_with_unused_result {
|
||||
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
|
||||
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
|
||||
pdl.pattern : benefit(1) {
|
||||
%root = operation
|
||||
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||
rewrite %root with "rewriter"
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module @constraint_with_result_multiple
|
||||
module @constraint_with_result_multiple {
|
||||
// check that native constraints work as expected even when multiple identical constraints are fused
|
||||
|
||||
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
|
||||
// CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
|
||||
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
|
||||
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
|
||||
pdl.pattern : benefit(1) {
|
||||
%root = operation
|
||||
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||
}
|
||||
pdl.pattern : benefit(1) {
|
||||
%root = operation
|
||||
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module @negated_constraint
|
||||
module @negated_constraint {
|
||||
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||
|
@ -0,0 +1,77 @@
|
||||
// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
|
||||
|
||||
// Ensuse that the dependency between add & less
|
||||
// causes them to be in the correct order.
|
||||
// CHECK-LABEL: matcher
|
||||
// CHECK: apply_constraint "return_attr_constraint"
|
||||
// CHECK: apply_constraint "use_attr_constraint"
|
||||
|
||||
module {
|
||||
pdl.pattern : benefit(1) {
|
||||
%0 = attribute
|
||||
%1 = types
|
||||
%2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
|
||||
%3 = attribute = 0 : i32
|
||||
%4 = attribute = 1 : i32
|
||||
%5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
|
||||
apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute)
|
||||
rewrite %2 with "rewriter"
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: matcher
|
||||
// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of
|
||||
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint"
|
||||
// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]]
|
||||
|
||||
pdl.pattern : benefit(1) {
|
||||
%inputOp = operation
|
||||
%result = result 0 of %inputOp
|
||||
%attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute
|
||||
%root = operation(%result : !pdl.value) {"attr" = %attr}
|
||||
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: matcher
|
||||
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr"
|
||||
// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0
|
||||
// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]]
|
||||
|
||||
pdl.pattern : benefit(1) {
|
||||
%attr = attribute = 10
|
||||
%value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value
|
||||
%root = operation(%value : !pdl.value)
|
||||
rewrite %root with "rewriter"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: matcher
|
||||
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr"
|
||||
// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
|
||||
// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
|
||||
|
||||
pdl.pattern : benefit(1) {
|
||||
%attr = attribute = 10
|
||||
%type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type
|
||||
%root = operation -> (%type : !pdl.type)
|
||||
rewrite %root with "rewriter"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: matcher
|
||||
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr"
|
||||
// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
|
||||
// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
|
||||
|
||||
pdl.pattern : benefit(1) {
|
||||
%attr = attribute = 10
|
||||
%types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type>
|
||||
%root = operation -> (%types : !pdl.range<type>)
|
||||
rewrite %root with "rewriter"
|
||||
}
|
@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
|
||||
|
||||
// -----
|
||||
|
||||
pdl.pattern @apply_constraint_with_no_results : benefit(1) {
|
||||
%root = operation
|
||||
apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
|
||||
rewrite %root with "rewriter"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
pdl.pattern @apply_constraint_with_results : benefit(1) {
|
||||
%root = operation
|
||||
%attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
|
||||
rewrite %root {
|
||||
apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
pdl.pattern @attribute_with_dict : benefit(1) {
|
||||
%root = operation
|
||||
rewrite %root {
|
||||
|
@ -109,6 +109,74 @@ module @ir attributes { test.apply_constraint_3 } {
|
||||
|
||||
// -----
|
||||
|
||||
// Test returning a type from a native constraint.
|
||||
module @patterns {
|
||||
pdl_interp.func @matcher(%root : !pdl.operation) {
|
||||
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
|
||||
|
||||
^pat:
|
||||
%new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
|
||||
|
||||
^pat2:
|
||||
pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
|
||||
|
||||
^end:
|
||||
pdl_interp.finalize
|
||||
}
|
||||
|
||||
module @rewriters {
|
||||
pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
|
||||
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
|
||||
pdl_interp.erase %root
|
||||
pdl_interp.finalize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test.apply_constraint_4
|
||||
// CHECK-NOT: "test.replaced_by_pattern"
|
||||
// CHECK: "test.replaced_by_pattern"() : () -> f32
|
||||
module @ir attributes { test.apply_constraint_4 } {
|
||||
"test.failure_op"() : () -> ()
|
||||
"test.success_op"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test success and failure cases of native constraints with pdl.range results.
|
||||
module @patterns {
|
||||
pdl_interp.func @matcher(%root : !pdl.operation) {
|
||||
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
|
||||
|
||||
^pat:
|
||||
%num_results = pdl_interp.create_attribute 2 : i32
|
||||
%types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
|
||||
|
||||
^pat1:
|
||||
pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
|
||||
|
||||
^end:
|
||||
pdl_interp.finalize
|
||||
}
|
||||
|
||||
module @rewriters {
|
||||
pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
|
||||
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
|
||||
pdl_interp.erase %root
|
||||
pdl_interp.finalize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test.apply_constraint_5
|
||||
// CHECK-NOT: "test.replaced_by_pattern"
|
||||
// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
|
||||
module @ir attributes { test.apply_constraint_5 } {
|
||||
"test.failure_op"() : () -> ()
|
||||
"test.success_op"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::ApplyRewriteOp
|
||||
|
@ -887,7 +887,7 @@ public:
|
||||
#include "TestTransformDialectExtensionTypes.cpp.inc"
|
||||
>();
|
||||
|
||||
auto verboseConstraint = [](PatternRewriter &rewriter,
|
||||
auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
|
||||
ArrayRef<PDLValue> pdlValues) {
|
||||
for (const PDLValue &pdlValue : pdlValues) {
|
||||
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
|
||||
|
@ -30,6 +30,50 @@ static LogicalResult customMultiEntityVariadicConstraint(
|
||||
return success();
|
||||
}
|
||||
|
||||
// Custom constraint that returns a value if the op is named test.success_op
|
||||
static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
|
||||
PDLResultList &results,
|
||||
ArrayRef<PDLValue> args) {
|
||||
auto *op = args[0].cast<Operation *>();
|
||||
if (op->getName().getStringRef() == "test.success_op") {
|
||||
StringAttr customAttr = rewriter.getStringAttr("test.success");
|
||||
results.push_back(customAttr);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Custom constraint that returns a type if the op is named test.success_op
|
||||
static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
|
||||
PDLResultList &results,
|
||||
ArrayRef<PDLValue> args) {
|
||||
auto *op = args[0].cast<Operation *>();
|
||||
if (op->getName().getStringRef() == "test.success_op") {
|
||||
results.push_back(rewriter.getF32Type());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Custom constraint that returns a type range of variable length if the op is
|
||||
// named test.success_op
|
||||
static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
|
||||
PDLResultList &results,
|
||||
ArrayRef<PDLValue> args) {
|
||||
auto *op = args[0].cast<Operation *>();
|
||||
int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
|
||||
|
||||
if (op->getName().getStringRef() == "test.success_op") {
|
||||
SmallVector<Type> types;
|
||||
for (int i = 0; i < numTypes; i++) {
|
||||
types.push_back(rewriter.getF32Type());
|
||||
}
|
||||
results.push_back(TypeRange(types));
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Custom creator invoked from PDL.
|
||||
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
|
||||
return rewriter.create(OperationState(op->getLoc(), "test.success"));
|
||||
@ -102,6 +146,12 @@ struct TestPDLByteCodePass
|
||||
customMultiEntityConstraint);
|
||||
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
|
||||
customMultiEntityVariadicConstraint);
|
||||
pdlPattern.registerConstraintFunction("op_constr_return_attr",
|
||||
customValueResultConstraint);
|
||||
pdlPattern.registerConstraintFunction("op_constr_return_type",
|
||||
customTypeResultConstraint);
|
||||
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
|
||||
customTypeRangeResultConstraint);
|
||||
pdlPattern.registerRewriteFunction("creator", customCreate);
|
||||
pdlPattern.registerRewriteFunction("var_creator",
|
||||
customVariadicResultCreate);
|
||||
|
@ -158,8 +158,3 @@ Pattern {
|
||||
|
||||
// CHECK: expected `;` after native declaration
|
||||
Constraint Foo() [{}]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: native Constraints currently do not support returning results
|
||||
Constraint Foo() -> Op;
|
||||
|
@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }];
|
||||
|
||||
// -----
|
||||
|
||||
// Test that native constraints support returning results.
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
|
||||
Constraint Foo() -> Attr;
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
|
||||
// CHECK: `Inputs`
|
||||
|
@ -298,6 +298,6 @@ def test_apply_native_constraint():
|
||||
pattern = PatternOp(1)
|
||||
with InsertionPoint(pattern.body):
|
||||
resultType = TypeOp()
|
||||
ApplyNativeConstraintOp("typeConstraint", args=[resultType])
|
||||
ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
|
||||
root = OperationOp(types=[resultType])
|
||||
RewriteOp(root, name="rewrite")
|
||||
|
Loading…
x
Reference in New Issue
Block a user