Reapply "[mlir][PDL] Add support for native constraints with results (#82760)"

with a small stack-use-after-scope fix in getConstraintPredicates()

This reverts commit c80e6edba4a9593f0587e27fa0ac825ebe174afd.
This commit is contained in:
Matthias Gehre 2024-03-01 23:32:27 +01:00
parent da591d390e
commit 8ec28af8ea
18 changed files with 558 additions and 99 deletions

View File

@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
let description = [{
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
has been registered externally with the consumer of PDL, to a given set of
entities.
entities and optionally return a number of values.
Example:
```mlir
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
```
}];
let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
}];
let hasVerifier = 1;
}

View File

@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let description = [{
`pdl_interp.apply_constraint` operations apply a generic constraint, that
has been registered with the interpreter, with a given set of positional
values. On success, this operation branches to the true destination,
values.
The constraint function may return any number of results.
On success, this operation branches to the true destination,
otherwise the false destination is taken. This behavior can be reversed
by setting the attribute `isNegated` to true.
@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` attr-dict `->` successors
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
`->` successors
}];
}

View File

@ -318,8 +318,9 @@ protected:
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
using PDLConstraintFunction =
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
using PDLConstraintFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
@ -726,7 +727,7 @@ std::enable_if_t<
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
PatternRewriter &rewriter,
PatternRewriter &rewriter, PDLResultList &,
ArrayRef<PDLValue> values) -> LogicalResult {
auto argIndices = std::make_index_sequence<
llvm::function_traits<ConstraintFnT>::num_args - 1>();
@ -842,10 +843,13 @@ public:
/// Register a constraint function with PDL. A constraint function may be
/// specified in one of two ways:
///
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
/// * `LogicalResult (PatternRewriter &,
/// PDLResultList &,
/// ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form.
/// the low-level PDLValue form, and the results are manually appended to
/// the given result list.
///
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
///
@ -960,8 +964,8 @@ public:
}
};
class PDLResultList {};
using PDLConstraintFunction =
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
using PDLConstraintFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;

View File

@ -50,7 +50,8 @@ private:
/// Generate interpreter operations for the tree rooted at the given matcher
/// node, in the specified region.
Block *generateMatcher(MatcherNode &node, Region &region);
Block *generateMatcher(MatcherNode &node, Region &region,
Block *block = nullptr);
/// Get or create an access to the provided positional value in the current
/// block. This operation may mutate the provided block pointer if nested
@ -148,6 +149,10 @@ private:
/// A mapping between pattern operations and the corresponding configuration
/// set.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
/// A mapping from a constraint question to the ApplyConstraintOp
/// that implements it.
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
};
} // namespace
@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
firstMatcherBlock->erase();
}
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
Block *block) {
// Push a new scope for the values used by this matcher.
Block *block = &region.emplaceBlock();
if (!block)
block = &region.emplaceBlock();
ValueMapScope scope(values);
// If this is the return node, simply insert the corresponding interpreter
@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
loc, cast<ArrayAttr>(rawTypeAttr));
break;
}
case Predicates::ConstraintResultPos: {
// Due to the order of traversal, the ApplyConstraintOp has already been
// created and we can find it in constraintOpMap.
auto *constrResPos = cast<ConstraintPosition>(pos);
auto i = constraintOpMap.find(constrResPos->getQuestion());
assert(i != constraintOpMap.end());
value = i->second->getResult(constrResPos->getIndex());
break;
}
default:
llvm_unreachable("Generating unknown Position getter");
break;
@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
args.push_back(getValueAt(currentBlock, position));
}
// Generate the matcher in the current (potentially nested) region
// and get the failure successor.
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
// Generate a new block as success successor and get the failure successor.
Block *success = &region->emplaceBlock();
Block *failure = failureBlockStack.back();
// Finally, create the predicate.
// Create the predicate.
builder.setInsertionPointToEnd(currentBlock);
Predicates::Kind kind = question->getKind();
switch (kind) {
@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
failure);
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
cstQuestion->getIsNegated(), success, failure);
constraintOpMap.insert({cstQuestion, applyConstraintOp});
break;
}
default:
llvm_unreachable("Generating unknown Predicate operation");
}
// Generate the matcher in the current (potentially nested) region.
// This might use the results of the current predicate.
generateMatcher(*boolNode->getSuccessNode(), *region, success);
}
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>

View File

@ -47,6 +47,7 @@ enum Kind : unsigned {
OperandPos,
OperandGroupPos,
AttributePos,
ConstraintResultPos,
ResultPos,
ResultGroupPos,
TypePos,
@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
bool isOperandDefiningOp() const;
};
//===----------------------------------------------------------------------===//
// ConstraintPosition
struct ConstraintQuestion;
/// A position describing the result of a native constraint. It saves the
/// corresponding ConstraintQuestion and result index to enable referring
/// back to them
struct ConstraintPosition
: public PredicateBase<ConstraintPosition, Position,
std::pair<ConstraintQuestion *, unsigned>,
Predicates::ConstraintResultPos> {
using PredicateBase::PredicateBase;
/// Returns the ConstraintQuestion to enable keeping track of the native
/// constraint this position stems from.
ConstraintQuestion *getQuestion() const { return key.first; }
// Returns the result index of this position
unsigned getIndex() const { return key.second; }
};
//===----------------------------------------------------------------------===//
// ResultPosition
@ -447,11 +470,13 @@ struct AttributeQuestion
: public PredicateBase<AttributeQuestion, Qualifier, void,
Predicates::AttributeQuestion> {};
/// Apply a parameterized constraint to multiple position values.
/// Apply a parameterized constraint to multiple position values and possibly
/// produce results.
struct ConstraintQuestion
: public PredicateBase<ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, bool>,
Predicates::ConstraintQuestion> {
: public PredicateBase<
ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
Predicates::ConstraintQuestion> {
using Base::Base;
/// Return the name of the constraint.
@ -460,15 +485,19 @@ struct ConstraintQuestion
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
/// Return the result types of the constraint.
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
/// Return the negation status of the constraint.
bool getIsNegated() const { return std::get<2>(key); }
bool getIsNegated() const { return std::get<3>(key); }
/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key)),
std::get<2>(key)});
alloc.copyInto(std::get<2>(key)),
std::get<3>(key)});
}
/// Returns a hash suitable for the given keytype.
@ -526,6 +555,7 @@ public:
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
registerParametricStorageType<ConstraintPosition>();
registerParametricStorageType<ForEachPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
@ -588,6 +618,12 @@ public:
return OperationPosition::get(uniquer, p);
}
// Returns a position for a new value created by a constraint.
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
unsigned index) {
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
}
/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@ -673,11 +709,11 @@ public:
}
/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
bool isNegated) {
return {
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
TrueAnswer::get(uniquer)};
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
ArrayRef<Type> resultTypes, bool isNegated) {
return {ConstraintQuestion::get(
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
TrueAnswer::get(uniquer)};
}
/// Create a predicate comparing a value with null.

View File

@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <queue>
@ -49,14 +50,15 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
DenseMap<Value, Position *> &inputs,
AttributePosition *pos) {
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
// If the attribute has a type or value, add a constraint.
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.getValueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
// If the attribute has a type or value, add a constraint.
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.getValueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
}
}
/// Collect all of the predicates for the given operand position.
@ -272,8 +274,27 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
ResultRange results = op.getResults();
PredicateBuilder::Predicate pred = builder.getConstraint(
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
op.getIsNegated());
// For each result register a position so it can be used later
for (auto [i, result] : llvm::enumerate(results)) {
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
ConstraintPosition *pos = builder.getConstraintPosition(q, i);
auto [it, inserted] = inputs.try_emplace(result, pos);
// If this is an input value that has been visited in the tree, add a
// constraint to ensure that both instances refer to the same value.
if (!inserted) {
Position *first = pos;
Position *second = it->second;
if (comparePosDepth(second, first))
std::tie(second, first) = std::make_pair(first, second);
predList.emplace_back(second, builder.getEqualTo(first));
}
}
predList.emplace_back(pos, pred);
}
@ -875,6 +896,49 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
*root = std::make_unique<ExitNode>();
}
/// Sorts the range begin/end with the partial order given by cmp.
template <typename Iterator, typename Compare>
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
while (begin != end) {
// Cannot compute sortBeforeOthers in the predicate of stable_partition
// because stable_partition will not keep the [begin, end) range intact
// while it runs.
llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
for (auto i = begin; i != end; ++i) {
if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
sortBeforeOthers.insert(*i);
}
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
return sortBeforeOthers.contains(a);
});
assert(next != begin && "not a partial ordering");
begin = next;
}
}
/// Returns true if 'b' depends on a result of 'a'.
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
if (!cqa)
return false;
auto positionDependsOnA = [&](Position *p) {
auto *cp = dyn_cast<ConstraintPosition>(p);
return cp && cp->getQuestion() == cqa;
};
if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
// Does any argument of b use a?
return llvm::any_of(cqb->getArgs(), positionDependsOnA);
}
if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
return positionDependsOnA(b->position) ||
positionDependsOnA(equalTo->getValue());
}
return positionDependsOnA(b->position);
}
/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher node.
std::unique_ptr<MatcherNode>
@ -955,6 +1019,10 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
return *lhs < *rhs;
});
// Mostly keep the now established order, but also ensure that
// ConstraintQuestions come after the results they use.
stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;
for (OrderedPredicateList &list : lists)

View File

@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
LogicalResult ApplyNativeConstraintOp::verify() {
if (getNumOperands() == 0)
return emitOpError("expected at least one argument");
if (llvm::any_of(getResults(), [](OpResult result) {
return isa<OperationType>(result.getType());
})) {
return emitOpError(
"returning an operation from a constraint is not supported");
}
return success();
}

View File

@ -769,11 +769,25 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::ApplyConstraintOp op,
ByteCodeWriter &writer) {
assert(constraintToMemIndex.count(op.getName()) &&
"expected index for constraint function");
// Constraints that should return a value have to be registered as rewrites.
// If a constraint and a rewrite of similar name are registered the
// constraint takes precedence
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
writer.appendPDLValueList(op.getArgs());
writer.append(ByteCodeField(op.getIsNegated()));
ResultRange results = op.getResults();
writer.append(ByteCodeField(results.size()));
for (Value result : results) {
// We record the expected kind of the result, so that we can provide extra
// verification of the native rewrite function and handle the failure case
// of constraints accordingly.
writer.appendPDLValueKind(result);
// Range results also need to append the range storage index.
if (isa<pdl::RangeType>(result.getType()))
writer.append(getRangeStorageIndex(result));
writer.append(result);
}
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
@ -786,11 +800,9 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
ResultRange results = op.getResults();
writer.append(ByteCodeField(results.size()));
for (Value result : results) {
// In debug mode we also record the expected kind of the result, so that we
// We record the expected kind of the result, so that we
// can provide extra verification of the native rewrite function.
#ifndef NDEBUG
writer.appendPDLValueKind(result);
#endif
// Range results also need to append the range storage index.
if (isa<pdl::RangeType>(result.getType()))
@ -1076,6 +1088,28 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
// ByteCode Execution
namespace {
/// This class is an instantiation of the PDLResultList that provides access to
/// the returned results. This API is not on `PDLResultList` to avoid
/// overexposing access to information specific solely to the ByteCode.
class ByteCodeRewriteResultList : public PDLResultList {
public:
ByteCodeRewriteResultList(unsigned maxNumResults)
: PDLResultList(maxNumResults) {}
/// Return the list of PDL results.
MutableArrayRef<PDLValue> getResults() { return results; }
/// Return the type ranges allocated by this list.
MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
return allocatedTypeRanges;
}
/// Return the value ranges allocated by this list.
MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
return allocatedValueRanges;
}
};
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
@ -1152,6 +1186,9 @@ private:
void executeSwitchResultCount();
void executeSwitchType();
void executeSwitchTypes();
void processNativeFunResults(ByteCodeRewriteResultList &results,
unsigned numResults,
LogicalResult &rewriteResult);
/// Pushes a code iterator to the stack.
void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
@ -1225,6 +1262,8 @@ private:
return T::getFromOpaquePointer(pointer);
}
void skip(size_t skipN) { curCodeIt += skipN; }
/// Jump to a specific successor based on a predicate value.
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
/// Jump to a specific successor based on a destination index.
@ -1381,33 +1420,11 @@ private:
ArrayRef<PDLConstraintFunction> constraintFunctions;
ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
/// This class is an instantiation of the PDLResultList that provides access to
/// the returned results. This API is not on `PDLResultList` to avoid
/// overexposing access to information specific solely to the ByteCode.
class ByteCodeRewriteResultList : public PDLResultList {
public:
ByteCodeRewriteResultList(unsigned maxNumResults)
: PDLResultList(maxNumResults) {}
/// Return the list of PDL results.
MutableArrayRef<PDLValue> getResults() { return results; }
/// Return the type ranges allocated by this list.
MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
return allocatedTypeRanges;
}
/// Return the value ranges allocated by this list.
MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
return allocatedValueRanges;
}
};
} // namespace
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
ByteCodeField fun_idx = read();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
@ -1422,8 +1439,29 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
llvm::dbgs() << " * isNegated: " << isNegated << "\n";
llvm::interleaveComma(args, llvm::dbgs());
});
// Invoke the constraint and jump to the proper destination.
selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
ByteCodeField numResults = read();
const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
ByteCodeRewriteResultList results(numResults);
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
ArrayRef<PDLValue> constraintResults = results.getResults();
LLVM_DEBUG({
if (succeeded(rewriteResult)) {
llvm::dbgs() << " * Constraint succeeded\n";
llvm::dbgs() << " * Results: ";
llvm::interleaveComma(constraintResults, llvm::dbgs());
llvm::dbgs() << "\n";
} else {
llvm::dbgs() << " * Constraint failed\n";
}
});
assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
"native PDL rewrite function succeeded but returned "
"unexpected number of results");
processNativeFunResults(results, numResults, rewriteResult);
// Depending on the constraint jump to the proper destination.
selectJump(isNegated != succeeded(rewriteResult));
}
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@ -1445,16 +1483,39 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
// Store the results in the bytecode memory.
for (PDLValue &result : results.getResults()) {
processNativeFunResults(results, numResults, rewriteResult);
if (failed(rewriteResult)) {
LLVM_DEBUG(llvm::dbgs() << " - Failed");
return failure();
}
return success();
}
void ByteCodeExecutor::processNativeFunResults(
ByteCodeRewriteResultList &results, unsigned numResults,
LogicalResult &rewriteResult) {
// Store the results in the bytecode memory or handle missing results on
// failure.
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
PDLValue::Kind resultKind = read<PDLValue::Kind>();
// Skip the according number of values on the buffer on failure and exit
// early as there are no results to process.
if (failed(rewriteResult)) {
if (resultKind == PDLValue::Kind::TypeRange ||
resultKind == PDLValue::Kind::ValueRange) {
skip(2);
} else {
skip(1);
}
return;
}
PDLValue result = results.getResults()[resultIdx];
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
// In debug mode we also verify the expected kind of the result.
#ifndef NDEBUG
assert(result.getKind() == read<PDLValue::Kind>() &&
"native PDL rewrite function returned an unexpected type of result");
#endif
assert(result.getKind() == resultKind &&
"native PDL rewrite function returned an unexpected type of "
"result");
// If the result is a range, we need to copy it over to the bytecodes
// range memory.
if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
@ -1476,13 +1537,6 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it));
// Process the result of the rewrite.
if (failed(rewriteResult)) {
LLVM_DEBUG(llvm::dbgs() << " - Failed");
return failure();
}
return success();
}
void ByteCodeExecutor::executeAreEqual() {

View File

@ -1362,12 +1362,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
if (failed(parseToken(Token::semicolon,
"expected `;` after native declaration")))
return failure();
// TODO: PDL should be able to support constraint results in certain
// situations, we should revise this.
if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
return emitError(
"native Constraints currently do not support returning results");
}
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}

View File

@ -79,6 +79,57 @@ module @constraints {
// -----
// CHECK-LABEL: module @constraint_with_result
module @constraint_with_result {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
}
// -----
// CHECK-LABEL: module @constraint_with_unused_result
module @constraint_with_unused_result {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @constraint_with_result_multiple
module @constraint_with_result_multiple {
// check that native constraints work as expected even when multiple identical constraints are fused
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
}
// -----
// CHECK-LABEL: module @negated_constraint
module @negated_constraint {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

View File

@ -0,0 +1,77 @@
// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
// Ensuse that the dependency between add & less
// causes them to be in the correct order.
// CHECK-LABEL: matcher
// CHECK: apply_constraint "return_attr_constraint"
// CHECK: apply_constraint "use_attr_constraint"
module {
pdl.pattern : benefit(1) {
%0 = attribute
%1 = types
%2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
%3 = attribute = 0 : i32
%4 = attribute = 1 : i32
%5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute)
rewrite %2 with "rewriter"
}
}
// -----
// CHECK-LABEL: matcher
// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint"
// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]]
pdl.pattern : benefit(1) {
%inputOp = operation
%result = result 0 of %inputOp
%attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute
%root = operation(%result : !pdl.value) {"attr" = %attr}
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
// -----
// CHECK-LABEL: matcher
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr"
// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0
// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]]
pdl.pattern : benefit(1) {
%attr = attribute = 10
%value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value
%root = operation(%value : !pdl.value)
rewrite %root with "rewriter"
}
// -----
// CHECK-LABEL: matcher
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr"
// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
pdl.pattern : benefit(1) {
%attr = attribute = 10
%type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type
%root = operation -> (%type : !pdl.type)
rewrite %root with "rewriter"
}
// -----
// CHECK-LABEL: matcher
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr"
// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
pdl.pattern : benefit(1) {
%attr = attribute = 10
%types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type>
%root = operation -> (%types : !pdl.range<type>)
rewrite %root with "rewriter"
}

View File

@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
// -----
pdl.pattern @apply_constraint_with_no_results : benefit(1) {
%root = operation
apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
rewrite %root with "rewriter"
}
// -----
pdl.pattern @apply_constraint_with_results : benefit(1) {
%root = operation
%attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
rewrite %root {
apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
}
}
// -----
pdl.pattern @attribute_with_dict : benefit(1) {
%root = operation
rewrite %root {

View File

@ -109,6 +109,74 @@ module @ir attributes { test.apply_constraint_3 } {
// -----
// Test returning a type from a native constraint.
module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
^pat:
%new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
^pat2:
pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
^end:
pdl_interp.finalize
}
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
pdl_interp.erase %root
pdl_interp.finalize
}
}
}
// CHECK-LABEL: test.apply_constraint_4
// CHECK-NOT: "test.replaced_by_pattern"
// CHECK: "test.replaced_by_pattern"() : () -> f32
module @ir attributes { test.apply_constraint_4 } {
"test.failure_op"() : () -> ()
"test.success_op"() : () -> ()
}
// -----
// Test success and failure cases of native constraints with pdl.range results.
module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
^pat:
%num_results = pdl_interp.create_attribute 2 : i32
%types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
^pat1:
pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
^end:
pdl_interp.finalize
}
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
pdl_interp.erase %root
pdl_interp.finalize
}
}
}
// CHECK-LABEL: test.apply_constraint_5
// CHECK-NOT: "test.replaced_by_pattern"
// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
module @ir attributes { test.apply_constraint_5 } {
"test.failure_op"() : () -> ()
"test.success_op"() : () -> ()
}
// -----
//===----------------------------------------------------------------------===//
// pdl_interp::ApplyRewriteOp

View File

@ -887,7 +887,7 @@ public:
#include "TestTransformDialectExtensionTypes.cpp.inc"
>();
auto verboseConstraint = [](PatternRewriter &rewriter,
auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
ArrayRef<PDLValue> pdlValues) {
for (const PDLValue &pdlValue : pdlValues) {
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {

View File

@ -30,6 +30,50 @@ static LogicalResult customMultiEntityVariadicConstraint(
return success();
}
// Custom constraint that returns a value if the op is named test.success_op
static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
auto *op = args[0].cast<Operation *>();
if (op->getName().getStringRef() == "test.success_op") {
StringAttr customAttr = rewriter.getStringAttr("test.success");
results.push_back(customAttr);
return success();
}
return failure();
}
// Custom constraint that returns a type if the op is named test.success_op
static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
auto *op = args[0].cast<Operation *>();
if (op->getName().getStringRef() == "test.success_op") {
results.push_back(rewriter.getF32Type());
return success();
}
return failure();
}
// Custom constraint that returns a type range of variable length if the op is
// named test.success_op
static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
auto *op = args[0].cast<Operation *>();
int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
if (op->getName().getStringRef() == "test.success_op") {
SmallVector<Type> types;
for (int i = 0; i < numTypes; i++) {
types.push_back(rewriter.getF32Type());
}
results.push_back(TypeRange(types));
return success();
}
return failure();
}
// Custom creator invoked from PDL.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
return rewriter.create(OperationState(op->getLoc(), "test.success"));
@ -102,6 +146,12 @@ struct TestPDLByteCodePass
customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
customMultiEntityVariadicConstraint);
pdlPattern.registerConstraintFunction("op_constr_return_attr",
customValueResultConstraint);
pdlPattern.registerConstraintFunction("op_constr_return_type",
customTypeResultConstraint);
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
customTypeRangeResultConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);

View File

@ -158,8 +158,3 @@ Pattern {
// CHECK: expected `;` after native declaration
Constraint Foo() [{}]
// -----
// CHECK: native Constraints currently do not support returning results
Constraint Foo() -> Op;

View File

@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }];
// -----
// Test that native constraints support returning results.
// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
Constraint Foo() -> Attr;
// -----
// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
// CHECK: `Inputs`

View File

@ -298,6 +298,6 @@ def test_apply_native_constraint():
pattern = PatternOp(1)
with InsertionPoint(pattern.body):
resultType = TypeOp()
ApplyNativeConstraintOp("typeConstraint", args=[resultType])
ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
root = OperationOp(types=[resultType])
RewriteOp(root, name="rewrite")