Reapply "Reapply "[mlir-query] Add function extraction feature to mlir-query""

Fix ASAN by erasing the op extracted post printing.

This reverts commit 732a5cba8c739ed40a7280b5d74ca717910c2c4c.
This commit is contained in:
Jacques Pienaar 2024-03-03 05:56:56 -08:00
parent 5b4759f9fd
commit 58b44c8102
10 changed files with 209 additions and 17 deletions

View File

@ -37,8 +37,12 @@ enum class ErrorType {
None,
// Parser Errors
ParserChainedExprInvalidArg,
ParserChainedExprNoCloseParen,
ParserChainedExprNoOpenParen,
ParserFailedToBuildMatcher,
ParserInvalidToken,
ParserMalformedChainedExpr,
ParserNoCloseParen,
ParserNoCode,
ParserNoComma,
@ -50,9 +54,10 @@ enum class ErrorType {
// Registry Errors
RegistryMatcherNotFound,
RegistryNotBindable,
RegistryValueNotFound,
RegistryWrongArgCount,
RegistryWrongArgType
RegistryWrongArgType,
};
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,

View File

@ -63,8 +63,15 @@ public:
bool match(Operation *op) const { return implementation->match(op); }
void setFunctionName(StringRef name) { functionName = name.str(); };
bool hasFunctionName() const { return !functionName.empty(); };
StringRef getFunctionName() const { return functionName; };
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
std::string functionName;
};
} // namespace mlir::query::matcher

View File

@ -6,6 +6,7 @@ add_mlir_library(MLIRQuery
${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
LINK_LIBS PUBLIC
MLIRFuncDialect
MLIRQueryMatcher
)

View File

@ -38,6 +38,8 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
case ErrorType::RegistryValueNotFound:
return "Value not found: $0";
case ErrorType::RegistryNotBindable:
return "Matcher does not support binding.";
case ErrorType::ParserStringError:
return "Error parsing string token: <$0>";
@ -57,6 +59,14 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
return "Unexpected end of code.";
case ErrorType::ParserOverloadedType:
return "Input value has unresolved overloaded type: $0";
case ErrorType::ParserMalformedChainedExpr:
return "Period not followed by valid chained call.";
case ErrorType::ParserChainedExprInvalidArg:
return "Missing/Invalid argument for the chained call.";
case ErrorType::ParserChainedExprNoCloseParen:
return "Missing ')' for the chained call.";
case ErrorType::ParserChainedExprNoOpenParen:
return "Missing '(' for the chained call.";
case ErrorType::ParserFailedToBuildMatcher:
return "Failed to build matcher: $0.";

View File

@ -26,12 +26,17 @@ struct Parser::TokenInfo {
text = newText;
}
// Known identifiers.
static const char *const ID_Extract;
llvm::StringRef text;
TokenKind kind = TokenKind::Eof;
SourceRange range;
VariantValue value;
};
const char *const Parser::TokenInfo::ID_Extract = "extract";
class Parser::CodeTokenizer {
public:
// Constructor with matcherCode and error
@ -298,6 +303,36 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
}
bool Parser::parseChainedExpression(std::string &argument) {
// Parse the parenthesized argument to .extract("foo")
// Note: EOF is handled inside the consume functions and would fail below when
// checking token kind.
const TokenInfo openToken = tokenizer->consumeNextToken();
const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
if (openToken.kind != TokenKind::OpenParen) {
error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
return false;
}
if (argumentToken.kind != TokenKind::Literal ||
!argumentToken.value.isString()) {
error->addError(argumentToken.range,
ErrorType::ParserChainedExprInvalidArg);
return false;
}
if (closeToken.kind != TokenKind::CloseParen) {
error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
return false;
}
// If all checks passed, extract the argument and return true.
argument = argumentToken.value.getString();
return true;
}
// Parse the arguments of a matcher
bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
const TokenInfo &nameToken, TokenInfo &endToken) {
@ -364,13 +399,34 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
return false;
}
std::string functionName;
if (tokenizer->peekNextToken().kind == TokenKind::Period) {
tokenizer->consumeNextToken();
TokenInfo chainCallToken = tokenizer->consumeNextToken();
if (chainCallToken.kind == TokenKind::CodeCompletion) {
addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
return false;
}
if (chainCallToken.kind != TokenKind::Ident ||
chainCallToken.text != TokenInfo::ID_Extract) {
error->addError(chainCallToken.range,
ErrorType::ParserMalformedChainedExpr);
return false;
}
if (chainCallToken.text == TokenInfo::ID_Extract &&
!parseChainedExpression(functionName))
return false;
}
if (!ctor)
return false;
// Merge the start and end infos.
SourceRange matcherRange = nameToken.range;
matcherRange.end = endToken.range.end;
VariantMatcher result =
sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
VariantMatcher result = sema->actOnMatcherExpression(
*ctor, matcherRange, functionName, args, error);
if (result.isNull())
return false;
*value = result;
@ -470,9 +526,10 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
}
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
Diagnostics *error) {
return RegistryManager::constructMatcher(ctor, nameRange, args, error);
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
error);
}
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(

View File

@ -64,10 +64,9 @@ public:
// Process a matcher expression. The caller takes ownership of the Matcher
// object returned.
virtual VariantMatcher
actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
llvm::ArrayRef<ParserValue> args,
Diagnostics *error) = 0;
virtual VariantMatcher actOnMatcherExpression(
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
// Look up a matcher by name in the matcher name found by the parser.
virtual std::optional<MatcherCtor>
@ -93,10 +92,11 @@ public:
std::optional<MatcherCtor>
lookupMatcherCtor(llvm::StringRef matcherName) override;
VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
SourceRange nameRange,
llvm::ArrayRef<ParserValue> args,
Diagnostics *error) override;
VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
SourceRange NameRange,
StringRef functionName,
ArrayRef<ParserValue> Args,
Diagnostics *Error) override;
std::vector<ArgKind> getAcceptedCompletionTypes(
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@ -153,6 +153,8 @@ private:
Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
const NamedValueMap *namedValues, Diagnostics *error);
bool parseChainedExpression(std::string &argument);
bool parseExpressionImpl(VariantValue *value);
bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,

View File

@ -132,8 +132,19 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
VariantMatcher RegistryManager::constructMatcher(
MatcherCtor ctor, internal::SourceRange nameRange,
llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
return ctor->create(nameRange, args, error);
llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
internal::Diagnostics *error) {
VariantMatcher out = ctor->create(nameRange, args, error);
if (functionName.empty() || out.isNull())
return out;
if (std::optional<DynMatcher> result = out.getDynMatcher()) {
result->setFunctionName(functionName);
return VariantMatcher::SingleMatcher(*result);
}
error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
return {};
}
} // namespace mlir::query::matcher

View File

@ -61,6 +61,7 @@ public:
static VariantMatcher constructMatcher(MatcherCtor ctor,
internal::SourceRange nameRange,
llvm::StringRef functionName,
ArrayRef<ParserValue> args,
internal::Diagnostics *error);
};

View File

@ -8,6 +8,8 @@
#include "mlir/Query/Query.h"
#include "QueryParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
#include "mlir/Support/LogicalResult.h"
@ -34,6 +36,70 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
"\"" + binding + "\" binds here");
}
// TODO: Extract into a helper function that can be reused outside query
// context.
static Operation *extractFunction(std::vector<Operation *> &ops,
MLIRContext *context,
llvm::StringRef functionName) {
context->loadDialect<func::FuncDialect>();
OpBuilder builder(context);
// Collect data for function creation
std::vector<Operation *> slice;
std::vector<Value> values;
std::vector<Type> outputTypes;
for (auto *op : ops) {
// Return op's operands are propagated, but the op itself isn't needed.
if (!isa<func::ReturnOp>(op))
slice.push_back(op);
// All results are returned by the extracted function.
outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
op->getResults().getTypes().end());
// Track all values that need to be taken as input to function.
values.insert(values.end(), op->getOperands().begin(),
op->getOperands().end());
}
// Create the function
FunctionType funcType =
builder.getFunctionType(ValueRange(values), outputTypes);
auto loc = builder.getUnknownLoc();
func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
// Map original values to function arguments
IRMapping mapper;
for (const auto &arg : llvm::enumerate(values))
mapper.map(arg.value(), funcOp.getArgument(arg.index()));
// Clone operations and build function body
std::vector<Operation *> clonedOps;
std::vector<Value> clonedVals;
for (Operation *slicedOp : slice) {
Operation *clonedOp =
clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
clonedOp->result_end());
}
// Add return operation
builder.create<func::ReturnOp>(loc, clonedVals);
// Remove unused function arguments
size_t currentIndex = 0;
while (currentIndex < funcOp.getNumArguments()) {
if (funcOp.getArgument(currentIndex).use_empty())
funcOp.eraseArgument(currentIndex);
else
++currentIndex;
}
return funcOp;
}
Query::~Query() = default;
mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@ -65,9 +131,22 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
int matchCount = 0;
std::vector<Operation *> matches =
matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
matcher::MatchFinder().getMatches(rootOp, matcher);
// An extract call is recognized by considering if the matcher has a name.
// TODO: Consider making the extract more explicit.
if (matcher.hasFunctionName()) {
auto functionName = matcher.getFunctionName();
Operation *function =
extractFunction(matches, rootOp->getContext(), functionName);
os << "\n" << *function << "\n\n";
function->erase();
return mlir::success();
}
os << "\n";
for (Operation *op : matches) {
os << "Match #" << ++matchCount << ":\n\n";

View File

@ -0,0 +1,19 @@
// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum0 = arith.addf %a, %b : f32
%sub0 = arith.subf %sum0, %c : f32
%mul0 = arith.mulf %a, %sub0 : f32
%sum1 = arith.addf %b, %c : f32
%mul1 = arith.mulf %sum1, %mul0 : f32
%sub2 = arith.subf %mul1, %a : f32
%sum2 = arith.addf %mul1, %b : f32
%mul2 = arith.mulf %sub2, %sum2 : f32
return %mul2 : f32
}