[mlir][bufferization] Allow mixed static/dynamic shapes in materialize_in_destination op (#92681)

This commit relaxes the verifier of
`bufferization.materialize_in_destination` such that mixed
static/dynamic dimensions are allowed for the source and destination
operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible,
but it is assumed that the dynamic dimension is `5` at runtime.

This commit fixes #91265.
This commit is contained in:
Matthias Springer 2024-06-01 12:04:56 +02:00 committed by GitHub
parent cd676e5b27
commit 9d4b20a44e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 37 additions and 10 deletions

View File

@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
[AllShapesMatch<["source", "dest"]>,
AllElementTypesMatch<["source", "dest"]>,
[AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<SubsetOpInterface,
@ -239,9 +238,9 @@ def Bufferization_MaterializeInDestinationOp
memref, `source` materializes in `dest`, which is already a buffer. The op
has no results in that case.
`source`, `dest` and `result` (if present) must have the same shape and
element type. If the op has a result, the types of `result` and `dest` must
match exactly (e.g., including any tensor encodings).
`source`, `dest` and `result` (if present) must have the same runtime shape
and element type. If the op has a result, the types of `result` and `dest`
must match exactly (e.g., including any tensor encodings).
By default, this op bufferizes to a memcpy from the future buffer of the
`source` tensor to the future buffer of the `dest` tensor or to the `dest`

View File

@ -686,6 +686,24 @@ LogicalResult MaterializeInDestinationOp::verify() {
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
"destination is of memref type");
TensorType srcType = getSource().getType();
ShapedType destType = cast<ShapedType>(getDest().getType());
if (srcType.hasRank() != destType.hasRank())
return emitOpError("source/destination shapes are incompatible");
if (srcType.hasRank()) {
if (srcType.getRank() != destType.getRank())
return emitOpError("rank mismatch between source and destination shape");
for (auto [src, dest] :
llvm::zip(srcType.getShape(), destType.getShape())) {
if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
// Cannot verify dynamic dimension size. Assume that that they match at
// runtime.
continue;
}
if (src != dest)
return emitOpError("source/destination shapes are incompatible");
}
}
return success();
}

View File

@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() {
// -----
func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
// expected-error @below{{failed to verify that all of {source, dest} have same shape}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) {
// expected-error @below{{source/destination shapes are incompatible}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32>
}
// -----
func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
// expected-error @below{{rank mismatch between source and destination shape}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
}
// -----

View File

@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
}
// CHECK-LABEL: func @test_materialize_in_destination_op
func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
-> tensor<?xf32> {
func.func @test_materialize_in_destination_op(
%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>,
%arg4: tensor<5xf32>) -> tensor<?xf32> {
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
%2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
return %1 : tensor<?xf32>
}