[mlir][scf] Extend option to yield replacement for multiple results case (#93144)

This patch extends the yield-replacement functionality to the multiple-results
case and adds a new optional argument, `yieldResultNumber`, indicating which
result(s) need to be yielded. If not given, all results are yielded by default.
Yun-Fly 2024-06-28 20:43:52 +08:00 committed by GitHub
parent 4169338e75
commit 7ef08eacd5
5 changed files with 241 additions and 45 deletions
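As context for the diffs below, a minimal usage sketch of the extended entry point (not taken from the patch itself; `rewriter`, `candidateSliceOp`, `fusedResult`, and `loops` are assumed to come from a preceding `scf::tileAndFuseProducerOfSlice` call):

// Hedged sketch: yield only result #0 of the fused producer back through the
// loop nest; omitting the last argument keeps the new default of yielding all
// results of the producer.
SmallVector<unsigned> yieldResultNumber = {0};
if (failed(scf::yieldReplacementForFusedProducer(
        rewriter, candidateSliceOp, fusedResult.value(), loops,
        yieldResultNumber)))
  return failure();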


@@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
///
/// The @param `yieldResultNumber` decides which result(s) will be yielded. If
/// not given, all `opResult`s of the fused producer are yielded.
LogicalResult yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops);
MutableArrayRef<LoopLikeOpInterface> loops,
ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {


@@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
For an operation to be "tiled and fused" with its (already tiled) consumer,
an operation has to implement the following additional method (see
description below):
- `generateResultTileValue
- `generateResultTileValue`
- `getIterationDomainTileFromResultTile`
For an operation to be "tiled and fused" with its (already tiled) producer,
an operation has to implement the following additional methods (see
@@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to return the tile of the iteration domain based
on the given tile of a certain result.
This method is required to allow operations to be "tiled and fused"
with an (already tiled) consumer. Given a tile of a result,
returns the tile of the iteration space that uses this tile.
- `resultNumber` is the result of the producer used by the consumer.
- `offsets` is the offset of the slice of the producer result used by
the tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the
consumer.
If fusion of the producer with the consumer is not legal for the
result, or if this mapping cannot be computed, the implementation
should return a failure.
In most cases `generateResultTileValue` could be implemented using the
`getIterationDomainTileFromResultTile` + `getTiledImplementation`
methods.
}],
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getIterationDomainTileFromResultTile",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.

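For illustration only (not part of this commit), a hedged sketch of how a downstream op whose results use identity indexing maps could implement the new hook in its TilingInterface external model; `MyElementwiseOp` is a hypothetical op, and for such an op the result tile coincides with the iteration-domain tile:

struct MyElementwiseOpTilingInterface
    : public TilingInterface::ExternalModel<MyElementwiseOpTilingInterface,
                                            MyElementwiseOp> {
  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    // Identity mapping between the result space and the iteration space:
    // forward the result tile unchanged as the iteration-domain tile.
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

With that hook in place, `generateResultTileValue` can follow the pattern shown for LinalgOp in the next hunk: compute the iteration-domain tile, then call `getTiledImplementation`.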

@@ -215,10 +215,11 @@ struct LinalgOpTilingInterface
return success();
}
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
LogicalResult getIterationDomainTileFromResultTile(
Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the output is a projected
@@ -232,9 +233,21 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
mappedOffsets, mappedSizes);
iterDomainOffsets, iterDomainSizes);
return success();
}
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
if (failed(getIterationDomainTileFromResultTile(
op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
return failure();
}
auto tilingInterfaceOp = cast<TilingInterface>(op);
FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);


@@ -953,49 +953,122 @@ mlir::scf::tileAndFuseProducerOfSlice(
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops) {
MutableArrayRef<LoopLikeOpInterface> loops,
ArrayRef<unsigned> yieldResultNumber) {
if (loops.empty())
return success();
OpResult fusableProducer = fusedProducerInfo.origProducer;
Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
if (succeeded(initValue)) {
Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
*tiledOwner = fusedProducerInfo.tiledOps[0];
YieldTiledValuesFn newYieldValuesFn =
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-> LogicalResult {
OpBuilder::InsertionGuard g(innerRewriter);
if (auto tiledDestStyleOp =
tiledAndFusedProducer
.getDefiningOp<DestinationStyleOpInterface>()) {
rewriter.setInsertionPoint(tiledDestStyleOp);
Value newRegionArg = newRegionIterArgs.back();
Location loc = originalOwner->getLoc();
// a. collect all init values to be appended
SmallVector<unsigned> initNumberList =
yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
0, originalOwner->getNumResults()))
: llvm::to_vector(yieldResultNumber);
SmallVector<Value> initValueList;
for (const auto &resultNumber : initNumberList) {
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, loc, originalOwner->getResult(resultNumber));
if (succeeded(initValue)) {
initValueList.push_back(initValue.value());
} else {
return failure();
}
}
YieldTiledValuesFn newYieldValuesFn =
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
OpBuilder::InsertionGuard g(innerRewriter);
// get sliceOp tile information
SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
sliceSizes = sliceOp.getMixedSizes();
// expect all strides of sliceOp to be 1
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 1);
}))
return failure();
unsigned sliceResultNumber =
fusedProducerInfo.origProducer.getResultNumber();
auto tilableOp = cast<TilingInterface>(originalOwner);
// b. get iterDomain Offset and Sizes based on sliceOp tile
SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
// skip tensor.pack/unpack/pad, which expect a single opResult
if (tilableOp->getNumResults() > 1 &&
failed(tilableOp.getIterationDomainTileFromResultTile(
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
iterDomainOffset, iterDomainSizes))) {
// In theory, it is unnecessary to raise an error here: even if the result
// tensor cannot be reconstructed, that should not break the current fusion.
// The reason we must return failure for now is that the callback
// `newYieldValuesFn` is called after the new init operand(s) have already
// been appended. It will take more refactoring to make sure the init
// operands are added consistently in the future. For more details, please
// refer to:
// https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
return failure();
}
// c. calculate the offsets and sizes of all OpResults based on the
// iteration domain tile
SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
for (const auto &resultNumber : initNumberList) {
if (resultNumber == sliceResultNumber) {
offsetList.push_back(sliceOffset);
sizesList.push_back(sliceSizes);
} else {
assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
// infer result tile according to the iteration domain tile
SmallVector<OpFoldResult> offset, sizes;
if (failed(tilableOp.getResultTilePosition(
rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
offset, sizes))) {
return failure();
}
offsetList.push_back(offset);
sizesList.push_back(sizes);
}
}
// d. create `extract_slice` for `iter_args` for DPS operation if necessary
if (auto tiledDestStyleOp =
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
rewriter.setInsertionPoint(tiledDestStyleOp);
for (const auto &&[index, newRegionArg] :
llvm::enumerate(newRegionIterArgs)) {
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
unsigned resultNumber = fusableProducer.getResultNumber();
loc, newRegionArg, offsetList[index], sizesList[index],
SmallVector<OpFoldResult>(offsetList[index].size(),
rewriter.getIndexAttr(1)));
unsigned resultNumber = initNumberList[index];
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
tiledOffset.emplace_back(sliceOp.getMixedOffsets());
tiledSizes.emplace_back(sliceOp.getMixedSizes());
return success();
};
}
return addInitOperandsToLoopNest(rewriter, loops,
SmallVector<Value>{initValue.value()},
newYieldValuesFn);
}
return success();
// e. prepare tiled offset and sizes for later `insert_slice` creation by
// caller
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
tiledResult.push_back(tiledOwner->getResult(resultNumber));
tiledOffset.emplace_back(offsetList[index]);
tiledSizes.emplace_back(sizesList[index]);
}
return success();
};
return addInitOperandsToLoopNest(rewriter, loops, initValueList,
newYieldValuesFn);
}
/// Implementation of tile consumer and fuse producer greedily.
@@ -1085,14 +1158,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
continue;
if (yieldReplacement) {
// Reconstruct and yield all opResults of fusableProducerOp by default. The
// caller can specify which ones to yield via the optional argument
// `yieldResultNumber` of `yieldReplacementForFusedProducer`.
Operation *fusableProducerOp = fusableProducer.getOwner();
if (failed(yieldReplacementForFusedProducer(
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
return rewriter.notifyMatchFailure(
fusableProducer.getOwner(), "failed to replacement value for this "
"oepration from within the tiled loop");
fusableProducerOp, "failed to replacement value for this "
"operation from within the tiled loop");
}
for (auto [index, result] :
llvm::enumerate(fusableProducerOp->getResults())) {
origValToResultNumber[result] = loops.front()->getNumResults() -
fusableProducerOp->getNumResults() +
index;
}
origValToResultNumber[fusableProducer] =
loops.front()->getNumResults() - 1;
}
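// Worked example of the index arithmetic above, using numbers consistent
// with the multiple_outputs_fusion_yield_all test added below: the consumer
// loop starts with one result, and yielding both producer results appends
// two more iter_args, so loops.front()->getNumResults() == 3 and
// fusableProducerOp->getNumResults() == 2. Producer result 0 then maps to
// loop result 3 - 2 + 0 == 1, producer result 1 maps to loop result
// 3 - 2 + 1 == 2, and the original consumer result stays at loop result 0.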
if (Operation *tiledAndFusedOp =


@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
// -----
func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
%rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
%rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
-> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
%out0, %out1 = linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (j, i)>],
iterator_types = ["parallel", "parallel"]
}
ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
%4 = arith.mulf %0, %1 : f32
%5 = arith.addf %0, %1 : f32
linalg.yield %4, %5: f32, f32
} -> (tensor<32x32xf32>, tensor<32x32xf32>)
%out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%add = transform.structured.match ops{["linalg.add"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%a, %b = transform.test.fuse_and_yield %add [16]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK: func.func @multiple_outputs_fusion_yield_all(
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[ADD_TILE:.+]] = linalg.add
// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
// CHECK-SAME: outs(%[[INIT2_TILE]] :
// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0