[mlir][scf] Extend option to yield replacement for multiple results case (#93144)
This patch extends the yield-replacement functionality to the multiple-results case and adds an optional argument called `yieldResultNumber` indicating which result(s) need to be yielded. If it is not given, all results are yielded by default.
This commit is contained in:
parent
4169338e75
commit
7ef08eacd5
@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
|
||||
/// where `%0` had other uses as well. If not reconstructed from within the loop
|
||||
/// body, uses of `%0` could not be replaced, making it still live and the
|
||||
/// fusion immaterial.
|
||||
///
|
||||
/// The @param `yieldResultNumber` decides which result(s) are yielded. If not
|
||||
/// given, all `opResult`s of the fused producer are yielded.
|
||||
LogicalResult yieldReplacementForFusedProducer(
|
||||
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
|
||||
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops);
|
||||
MutableArrayRef<LoopLikeOpInterface> loops,
|
||||
ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
|
||||
|
||||
/// Transformation information returned after tile and fuse.
|
||||
struct SCFTileAndFuseResult {
|
||||
|
@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
||||
For an operation to be "tiled and fused" with its (already tiled) consumer,
|
||||
an operation has to implement the following additional method (see
|
||||
description below):
|
||||
- `generateResultTileValue
|
||||
- `generateResultTileValue`
|
||||
- `getIterationDomainTileFromResultTile`
|
||||
|
||||
For an operation to be "tiled and fused" with its (already tiled) producer,
|
||||
an operation has to implement the following additional methods (see
|
||||
@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
||||
return failure();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Method to return the tile of the iteration domain based
|
||||
on the given tile of a particular result.
|
||||
|
||||
This method is required to allow operations to be "tiled and fused"
|
||||
with an (already tiled) consumer. Given a tile of a result,
|
||||
returns the tile of the iteration space that uses this tile.
|
||||
- `resultNumber` is the result of the producer used by the consumer.
|
||||
- `offsets` is the offset of the slice of the producer result used by
|
||||
the tiled implementation of the consumer.
|
||||
- `sizes` is the size of the slice of the producer result used by the
|
||||
consumer.
|
||||
If fusion of the producer with the consumer is not legal for the
|
||||
result, or if this mapping cannot be computed, the implementation
|
||||
should return a failure.
|
||||
|
||||
For most cases `generateResultTileValue` could be implemented using
|
||||
`getIterationDomainTileFromResultTile` + `getTiledImplementation`
|
||||
methods.
|
||||
}],
|
||||
/*retType=*/"::mlir::LogicalResult",
|
||||
/*methodName=*/"getIterationDomainTileFromResultTile",
|
||||
/*args=*/(ins
|
||||
"OpBuilder &":$b,
|
||||
"unsigned":$resultNumber,
|
||||
"ArrayRef<OpFoldResult> ":$offsets,
|
||||
"ArrayRef<OpFoldResult> ":$sizes,
|
||||
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
|
||||
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return failure();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Generates the scalar implementation of the operation.
|
||||
|
@ -215,10 +215,11 @@ struct LinalgOpTilingInterface
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<TilingResult>
|
||||
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes) const {
|
||||
LogicalResult getIterationDomainTileFromResultTile(
|
||||
Operation *op, OpBuilder &b, unsigned resultNumber,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
||||
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
|
||||
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
|
||||
auto linalgOp = cast<LinalgOp>(op);
|
||||
|
||||
// Check that the indexing map used for the output is a projected
|
||||
@ -232,9 +233,21 @@ struct LinalgOpTilingInterface
|
||||
"unhandled tiled implementation generation when result is not "
|
||||
"accessed using a permuted projection");
|
||||
}
|
||||
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
|
||||
|
||||
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
|
||||
mappedOffsets, mappedSizes);
|
||||
iterDomainOffsets, iterDomainSizes);
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<TilingResult>
|
||||
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes) const {
|
||||
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
|
||||
if (failed(getIterationDomainTileFromResultTile(
|
||||
op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
|
||||
return failure();
|
||||
}
|
||||
auto tilingInterfaceOp = cast<TilingInterface>(op);
|
||||
FailureOr<TilingResult> tilingResult =
|
||||
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
|
||||
|
@ -953,49 +953,122 @@ mlir::scf::tileAndFuseProducerOfSlice(
|
||||
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
|
||||
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
|
||||
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops) {
|
||||
MutableArrayRef<LoopLikeOpInterface> loops,
|
||||
ArrayRef<unsigned> yieldResultNumber) {
|
||||
if (loops.empty())
|
||||
return success();
|
||||
|
||||
OpResult fusableProducer = fusedProducerInfo.origProducer;
|
||||
Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
|
||||
FailureOr<Value> initValue = tensor::getOrCreateDestination(
|
||||
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
|
||||
if (succeeded(initValue)) {
|
||||
Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
|
||||
*tiledOwner = fusedProducerInfo.tiledOps[0];
|
||||
|
||||
YieldTiledValuesFn newYieldValuesFn =
|
||||
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
|
||||
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
|
||||
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
|
||||
SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
|
||||
-> LogicalResult {
|
||||
OpBuilder::InsertionGuard g(innerRewriter);
|
||||
if (auto tiledDestStyleOp =
|
||||
tiledAndFusedProducer
|
||||
.getDefiningOp<DestinationStyleOpInterface>()) {
|
||||
rewriter.setInsertionPoint(tiledDestStyleOp);
|
||||
Value newRegionArg = newRegionIterArgs.back();
|
||||
Location loc = originalOwner->getLoc();
|
||||
// a. collect all init Value to be appended
|
||||
SmallVector<unsigned> initNumberList =
|
||||
yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
|
||||
0, originalOwner->getNumResults()))
|
||||
: llvm::to_vector(yieldResultNumber);
|
||||
SmallVector<Value> initValueList;
|
||||
for (const auto &resultNumber : initNumberList) {
|
||||
FailureOr<Value> initValue = tensor::getOrCreateDestination(
|
||||
rewriter, loc, originalOwner->getResult(resultNumber));
|
||||
if (succeeded(initValue)) {
|
||||
initValueList.push_back(initValue.value());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
YieldTiledValuesFn newYieldValuesFn =
|
||||
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
|
||||
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
|
||||
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
|
||||
SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
|
||||
OpBuilder::InsertionGuard g(innerRewriter);
|
||||
|
||||
// get sliceOp tile information
|
||||
SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
|
||||
sliceSizes = sliceOp.getMixedSizes();
|
||||
|
||||
// expect all strides of sliceOp being 1
|
||||
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
|
||||
return !isConstantIntValue(ofr, 1);
|
||||
}))
|
||||
return failure();
|
||||
|
||||
unsigned sliceResultNumber =
|
||||
fusedProducerInfo.origProducer.getResultNumber();
|
||||
|
||||
auto tilableOp = cast<TilingInterface>(originalOwner);
|
||||
// b. get iterDomain Offset and Sizes based on sliceOp tile
|
||||
SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
|
||||
// skip tensor.pack/unpack/pad, which expects single opResult
|
||||
if (tilableOp->getNumResults() > 1 &&
|
||||
failed(tilableOp.getIterationDomainTileFromResultTile(
|
||||
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
|
||||
iterDomainOffset, iterDomainSizes))) {
|
||||
// In theory, it is unnecessary to raise an error here. Actually although
|
||||
// it fails to reconstruct the result tensor, it should not break the current
|
||||
// fusion anyway. The reason why we must return failure currently is that
|
||||
// the callback function `newYieldValuesFn` will be called after new init
|
||||
// operand(s) has already been appended. It will take more refactoring to
|
||||
// make sure the init operands are added consistently in the future. For
|
||||
// more details, please refer to:
|
||||
// https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
|
||||
return failure();
|
||||
}
|
||||
|
||||
// c. calculate offsets and sizes info of all OpResults respectively based
|
||||
// on iteration Domain Tile
|
||||
SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
|
||||
for (const auto &resultNumber : initNumberList) {
|
||||
if (resultNumber == sliceResultNumber) {
|
||||
offsetList.push_back(sliceOffset);
|
||||
sizesList.push_back(sliceSizes);
|
||||
} else {
|
||||
assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
|
||||
// infer result tile according to the iteration domain tile
|
||||
SmallVector<OpFoldResult> offset, sizes;
|
||||
if (failed(tilableOp.getResultTilePosition(
|
||||
rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
|
||||
offset, sizes))) {
|
||||
return failure();
|
||||
}
|
||||
offsetList.push_back(offset);
|
||||
sizesList.push_back(sizes);
|
||||
}
|
||||
}
|
||||
|
||||
// d. create `extract_slice` for `iter_args` for DPS operation if necessary
|
||||
if (auto tiledDestStyleOp =
|
||||
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
|
||||
rewriter.setInsertionPoint(tiledDestStyleOp);
|
||||
for (const auto &&[index, newRegionArg] :
|
||||
llvm::enumerate(newRegionIterArgs)) {
|
||||
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
|
||||
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
|
||||
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
|
||||
unsigned resultNumber = fusableProducer.getResultNumber();
|
||||
loc, newRegionArg, offsetList[index], sizesList[index],
|
||||
SmallVector<OpFoldResult>(offsetList[index].size(),
|
||||
rewriter.getIndexAttr(1)));
|
||||
unsigned resultNumber = initNumberList[index];
|
||||
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
|
||||
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
|
||||
});
|
||||
}
|
||||
Block *block = rewriter.getInsertionPoint()->getBlock();
|
||||
rewriter.setInsertionPoint(block->getTerminator());
|
||||
tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
|
||||
tiledOffset.emplace_back(sliceOp.getMixedOffsets());
|
||||
tiledSizes.emplace_back(sliceOp.getMixedSizes());
|
||||
return success();
|
||||
};
|
||||
}
|
||||
|
||||
return addInitOperandsToLoopNest(rewriter, loops,
|
||||
SmallVector<Value>{initValue.value()},
|
||||
newYieldValuesFn);
|
||||
}
|
||||
return success();
|
||||
// e. prepare tiled offset and sizes for later `insert_slice` creation by
|
||||
// caller
|
||||
Block *block = rewriter.getInsertionPoint()->getBlock();
|
||||
rewriter.setInsertionPoint(block->getTerminator());
|
||||
for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
|
||||
tiledResult.push_back(tiledOwner->getResult(resultNumber));
|
||||
tiledOffset.emplace_back(offsetList[index]);
|
||||
tiledSizes.emplace_back(sizesList[index]);
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
return addInitOperandsToLoopNest(rewriter, loops, initValueList,
|
||||
newYieldValuesFn);
|
||||
}
|
||||
|
||||
/// Implementation of tile consumer and fuse producer greedily.
|
||||
@ -1085,14 +1158,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
|
||||
continue;
|
||||
|
||||
if (yieldReplacement) {
|
||||
// Reconstruct and yield all opResult of fusableProducerOp by default. The
|
||||
// caller can specify which one to yield by designating the optional argument
|
||||
// named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
|
||||
Operation *fusableProducerOp = fusableProducer.getOwner();
|
||||
if (failed(yieldReplacementForFusedProducer(
|
||||
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
fusableProducer.getOwner(), "failed to replacement value for this "
|
||||
"oepration from within the tiled loop");
|
||||
fusableProducerOp, "failed to replacement value for this "
|
||||
"operation from within the tiled loop");
|
||||
}
|
||||
for (auto [index, result] :
|
||||
llvm::enumerate(fusableProducerOp->getResults())) {
|
||||
origValToResultNumber[result] = loops.front()->getNumResults() -
|
||||
fusableProducerOp->getNumResults() +
|
||||
index;
|
||||
}
|
||||
origValToResultNumber[fusableProducer] =
|
||||
loops.front()->getNumResults() - 1;
|
||||
}
|
||||
|
||||
if (Operation *tiledAndFusedOp =
|
||||
|
@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
|
||||
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
|
||||
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
|
||||
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
|
||||
|
||||
// -----
|
||||
|
||||
func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
|
||||
%rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
|
||||
%rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
|
||||
-> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
|
||||
%out0, %out1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (j, i)>],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
|
||||
outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
|
||||
^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
|
||||
%4 = arith.mulf %0, %1 : f32
|
||||
%5 = arith.addf %0, %1 : f32
|
||||
linalg.yield %4, %5: f32, f32
|
||||
} -> (tensor<32x32xf32>, tensor<32x32xf32>)
|
||||
|
||||
%out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
|
||||
|
||||
return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
|
||||
%add = transform.structured.match ops{["linalg.add"]} in %arg0
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%a, %b = transform.test.fuse_and_yield %add [16]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
// CHECK: func.func @multiple_outputs_fusion_yield_all(
|
||||
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
|
||||
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
|
||||
// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
|
||||
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
|
||||
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
|
||||
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
|
||||
// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
|
||||
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
|
||||
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
|
||||
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
|
||||
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
|
||||
// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
|
||||
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
|
||||
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
|
||||
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
|
||||
// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
|
||||
// CHECK: %[[ADD_TILE:.+]] = linalg.add
|
||||
// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
|
||||
// CHECK-SAME: outs(%[[INIT2_TILE]] :
|
||||
// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
|
||||
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
|
||||
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
|
||||
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
|
||||
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
|
||||
|
Loading…
x
Reference in New Issue
Block a user