Reapply "Reapply "[mlir-query] Add function extraction feature to mlir-query""
Fix ASAN by erasing the op extracted post printing. This reverts commit 732a5cba8c739ed40a7280b5d74ca717910c2c4c.
This commit is contained in:
parent
5b4759f9fd
commit
58b44c8102
@ -37,8 +37,12 @@ enum class ErrorType {
|
||||
None,
|
||||
|
||||
// Parser Errors
|
||||
ParserChainedExprInvalidArg,
|
||||
ParserChainedExprNoCloseParen,
|
||||
ParserChainedExprNoOpenParen,
|
||||
ParserFailedToBuildMatcher,
|
||||
ParserInvalidToken,
|
||||
ParserMalformedChainedExpr,
|
||||
ParserNoCloseParen,
|
||||
ParserNoCode,
|
||||
ParserNoComma,
|
||||
@ -50,9 +54,10 @@ enum class ErrorType {
|
||||
|
||||
// Registry Errors
|
||||
RegistryMatcherNotFound,
|
||||
RegistryNotBindable,
|
||||
RegistryValueNotFound,
|
||||
RegistryWrongArgCount,
|
||||
RegistryWrongArgType
|
||||
RegistryWrongArgType,
|
||||
};
|
||||
|
||||
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
|
||||
|
@ -63,8 +63,15 @@ public:
|
||||
|
||||
bool match(Operation *op) const { return implementation->match(op); }
|
||||
|
||||
void setFunctionName(StringRef name) { functionName = name.str(); };
|
||||
|
||||
bool hasFunctionName() const { return !functionName.empty(); };
|
||||
|
||||
StringRef getFunctionName() const { return functionName; };
|
||||
|
||||
private:
|
||||
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
|
||||
std::string functionName;
|
||||
};
|
||||
|
||||
} // namespace mlir::query::matcher
|
||||
|
@ -6,6 +6,7 @@ add_mlir_library(MLIRQuery
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRFuncDialect
|
||||
MLIRQueryMatcher
|
||||
)
|
||||
|
||||
|
@ -38,6 +38,8 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
|
||||
return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
|
||||
case ErrorType::RegistryValueNotFound:
|
||||
return "Value not found: $0";
|
||||
case ErrorType::RegistryNotBindable:
|
||||
return "Matcher does not support binding.";
|
||||
|
||||
case ErrorType::ParserStringError:
|
||||
return "Error parsing string token: <$0>";
|
||||
@ -57,6 +59,14 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
|
||||
return "Unexpected end of code.";
|
||||
case ErrorType::ParserOverloadedType:
|
||||
return "Input value has unresolved overloaded type: $0";
|
||||
case ErrorType::ParserMalformedChainedExpr:
|
||||
return "Period not followed by valid chained call.";
|
||||
case ErrorType::ParserChainedExprInvalidArg:
|
||||
return "Missing/Invalid argument for the chained call.";
|
||||
case ErrorType::ParserChainedExprNoCloseParen:
|
||||
return "Missing ')' for the chained call.";
|
||||
case ErrorType::ParserChainedExprNoOpenParen:
|
||||
return "Missing '(' for the chained call.";
|
||||
case ErrorType::ParserFailedToBuildMatcher:
|
||||
return "Failed to build matcher: $0.";
|
||||
|
||||
|
@ -26,12 +26,17 @@ struct Parser::TokenInfo {
|
||||
text = newText;
|
||||
}
|
||||
|
||||
// Known identifiers.
|
||||
static const char *const ID_Extract;
|
||||
|
||||
llvm::StringRef text;
|
||||
TokenKind kind = TokenKind::Eof;
|
||||
SourceRange range;
|
||||
VariantValue value;
|
||||
};
|
||||
|
||||
const char *const Parser::TokenInfo::ID_Extract = "extract";
|
||||
|
||||
class Parser::CodeTokenizer {
|
||||
public:
|
||||
// Constructor with matcherCode and error
|
||||
@ -298,6 +303,36 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
|
||||
return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
|
||||
}
|
||||
|
||||
bool Parser::parseChainedExpression(std::string &argument) {
|
||||
// Parse the parenthesized argument to .extract("foo")
|
||||
// Note: EOF is handled inside the consume functions and would fail below when
|
||||
// checking token kind.
|
||||
const TokenInfo openToken = tokenizer->consumeNextToken();
|
||||
const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
|
||||
const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
|
||||
|
||||
if (openToken.kind != TokenKind::OpenParen) {
|
||||
error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (argumentToken.kind != TokenKind::Literal ||
|
||||
!argumentToken.value.isString()) {
|
||||
error->addError(argumentToken.range,
|
||||
ErrorType::ParserChainedExprInvalidArg);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (closeToken.kind != TokenKind::CloseParen) {
|
||||
error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
|
||||
return false;
|
||||
}
|
||||
|
||||
// If all checks passed, extract the argument and return true.
|
||||
argument = argumentToken.value.getString();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse the arguments of a matcher
|
||||
bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
|
||||
const TokenInfo &nameToken, TokenInfo &endToken) {
|
||||
@ -364,13 +399,34 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string functionName;
|
||||
if (tokenizer->peekNextToken().kind == TokenKind::Period) {
|
||||
tokenizer->consumeNextToken();
|
||||
TokenInfo chainCallToken = tokenizer->consumeNextToken();
|
||||
if (chainCallToken.kind == TokenKind::CodeCompletion) {
|
||||
addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (chainCallToken.kind != TokenKind::Ident ||
|
||||
chainCallToken.text != TokenInfo::ID_Extract) {
|
||||
error->addError(chainCallToken.range,
|
||||
ErrorType::ParserMalformedChainedExpr);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (chainCallToken.text == TokenInfo::ID_Extract &&
|
||||
!parseChainedExpression(functionName))
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctor)
|
||||
return false;
|
||||
// Merge the start and end infos.
|
||||
SourceRange matcherRange = nameToken.range;
|
||||
matcherRange.end = endToken.range.end;
|
||||
VariantMatcher result =
|
||||
sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
|
||||
VariantMatcher result = sema->actOnMatcherExpression(
|
||||
*ctor, matcherRange, functionName, args, error);
|
||||
if (result.isNull())
|
||||
return false;
|
||||
*value = result;
|
||||
@ -470,9 +526,10 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
|
||||
}
|
||||
|
||||
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
|
||||
MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
|
||||
Diagnostics *error) {
|
||||
return RegistryManager::constructMatcher(ctor, nameRange, args, error);
|
||||
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
|
||||
llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
|
||||
return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
|
||||
error);
|
||||
}
|
||||
|
||||
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
|
||||
|
@ -64,10 +64,9 @@ public:
|
||||
|
||||
// Process a matcher expression. The caller takes ownership of the Matcher
|
||||
// object returned.
|
||||
virtual VariantMatcher
|
||||
actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
|
||||
llvm::ArrayRef<ParserValue> args,
|
||||
Diagnostics *error) = 0;
|
||||
virtual VariantMatcher actOnMatcherExpression(
|
||||
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
|
||||
llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
|
||||
|
||||
// Look up a matcher by name in the matcher name found by the parser.
|
||||
virtual std::optional<MatcherCtor>
|
||||
@ -93,10 +92,11 @@ public:
|
||||
std::optional<MatcherCtor>
|
||||
lookupMatcherCtor(llvm::StringRef matcherName) override;
|
||||
|
||||
VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
|
||||
SourceRange nameRange,
|
||||
llvm::ArrayRef<ParserValue> args,
|
||||
Diagnostics *error) override;
|
||||
VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
|
||||
SourceRange NameRange,
|
||||
StringRef functionName,
|
||||
ArrayRef<ParserValue> Args,
|
||||
Diagnostics *Error) override;
|
||||
|
||||
std::vector<ArgKind> getAcceptedCompletionTypes(
|
||||
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
|
||||
@ -153,6 +153,8 @@ private:
|
||||
Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
|
||||
const NamedValueMap *namedValues, Diagnostics *error);
|
||||
|
||||
bool parseChainedExpression(std::string &argument);
|
||||
|
||||
bool parseExpressionImpl(VariantValue *value);
|
||||
|
||||
bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
|
||||
|
@ -132,8 +132,19 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
|
||||
|
||||
VariantMatcher RegistryManager::constructMatcher(
|
||||
MatcherCtor ctor, internal::SourceRange nameRange,
|
||||
llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
|
||||
return ctor->create(nameRange, args, error);
|
||||
llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
|
||||
internal::Diagnostics *error) {
|
||||
VariantMatcher out = ctor->create(nameRange, args, error);
|
||||
if (functionName.empty() || out.isNull())
|
||||
return out;
|
||||
|
||||
if (std::optional<DynMatcher> result = out.getDynMatcher()) {
|
||||
result->setFunctionName(functionName);
|
||||
return VariantMatcher::SingleMatcher(*result);
|
||||
}
|
||||
|
||||
error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace mlir::query::matcher
|
||||
|
@ -61,6 +61,7 @@ public:
|
||||
|
||||
static VariantMatcher constructMatcher(MatcherCtor ctor,
|
||||
internal::SourceRange nameRange,
|
||||
llvm::StringRef functionName,
|
||||
ArrayRef<ParserValue> args,
|
||||
internal::Diagnostics *error);
|
||||
};
|
||||
|
@ -8,6 +8,8 @@
|
||||
|
||||
#include "mlir/Query/Query.h"
|
||||
#include "QueryParser.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Query/Matcher/MatchFinder.h"
|
||||
#include "mlir/Query/QuerySession.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
@ -34,6 +36,70 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
|
||||
"\"" + binding + "\" binds here");
|
||||
}
|
||||
|
||||
// TODO: Extract into a helper function that can be reused outside query
|
||||
// context.
|
||||
static Operation *extractFunction(std::vector<Operation *> &ops,
|
||||
MLIRContext *context,
|
||||
llvm::StringRef functionName) {
|
||||
context->loadDialect<func::FuncDialect>();
|
||||
OpBuilder builder(context);
|
||||
|
||||
// Collect data for function creation
|
||||
std::vector<Operation *> slice;
|
||||
std::vector<Value> values;
|
||||
std::vector<Type> outputTypes;
|
||||
|
||||
for (auto *op : ops) {
|
||||
// Return op's operands are propagated, but the op itself isn't needed.
|
||||
if (!isa<func::ReturnOp>(op))
|
||||
slice.push_back(op);
|
||||
|
||||
// All results are returned by the extracted function.
|
||||
outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
|
||||
op->getResults().getTypes().end());
|
||||
|
||||
// Track all values that need to be taken as input to function.
|
||||
values.insert(values.end(), op->getOperands().begin(),
|
||||
op->getOperands().end());
|
||||
}
|
||||
|
||||
// Create the function
|
||||
FunctionType funcType =
|
||||
builder.getFunctionType(ValueRange(values), outputTypes);
|
||||
auto loc = builder.getUnknownLoc();
|
||||
func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
|
||||
|
||||
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
|
||||
|
||||
// Map original values to function arguments
|
||||
IRMapping mapper;
|
||||
for (const auto &arg : llvm::enumerate(values))
|
||||
mapper.map(arg.value(), funcOp.getArgument(arg.index()));
|
||||
|
||||
// Clone operations and build function body
|
||||
std::vector<Operation *> clonedOps;
|
||||
std::vector<Value> clonedVals;
|
||||
for (Operation *slicedOp : slice) {
|
||||
Operation *clonedOp =
|
||||
clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
|
||||
clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
|
||||
clonedOp->result_end());
|
||||
}
|
||||
// Add return operation
|
||||
builder.create<func::ReturnOp>(loc, clonedVals);
|
||||
|
||||
// Remove unused function arguments
|
||||
size_t currentIndex = 0;
|
||||
while (currentIndex < funcOp.getNumArguments()) {
|
||||
if (funcOp.getArgument(currentIndex).use_empty())
|
||||
funcOp.eraseArgument(currentIndex);
|
||||
else
|
||||
++currentIndex;
|
||||
}
|
||||
|
||||
return funcOp;
|
||||
}
|
||||
|
||||
Query::~Query() = default;
|
||||
|
||||
mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
|
||||
@ -65,9 +131,22 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
|
||||
|
||||
mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
|
||||
QuerySession &qs) const {
|
||||
Operation *rootOp = qs.getRootOp();
|
||||
int matchCount = 0;
|
||||
std::vector<Operation *> matches =
|
||||
matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
|
||||
matcher::MatchFinder().getMatches(rootOp, matcher);
|
||||
|
||||
// An extract call is recognized by considering if the matcher has a name.
|
||||
// TODO: Consider making the extract more explicit.
|
||||
if (matcher.hasFunctionName()) {
|
||||
auto functionName = matcher.getFunctionName();
|
||||
Operation *function =
|
||||
extractFunction(matches, rootOp->getContext(), functionName);
|
||||
os << "\n" << *function << "\n\n";
|
||||
function->erase();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
os << "\n";
|
||||
for (Operation *op : matches) {
|
||||
os << "Match #" << ++matchCount << ":\n\n";
|
||||
|
19
mlir/test/mlir-query/function-extraction.mlir
Normal file
19
mlir/test/mlir-query/function-extraction.mlir
Normal file
@ -0,0 +1,19 @@
|
||||
// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
|
||||
|
||||
// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
|
||||
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
|
||||
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
|
||||
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
|
||||
// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
|
||||
|
||||
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
|
||||
%sum0 = arith.addf %a, %b : f32
|
||||
%sub0 = arith.subf %sum0, %c : f32
|
||||
%mul0 = arith.mulf %a, %sub0 : f32
|
||||
%sum1 = arith.addf %b, %c : f32
|
||||
%mul1 = arith.mulf %sum1, %mul0 : f32
|
||||
%sub2 = arith.subf %mul1, %a : f32
|
||||
%sum2 = arith.addf %mul1, %b : f32
|
||||
%mul2 = arith.mulf %sub2, %sum2 : f32
|
||||
return %mul2 : f32
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user