[mlir][bufferization] Allow mixed static/dynamic shapes in materialize_in_destination op (#92681)
This commit relaxes the verifier of `bufferization.materialize_in_destination` such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible, but it is assumed that the dynamic dimension is `5` at runtime. This commit fixes #91265.
parent cd676e5b27
commit 9d4b20a44e
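For illustration only (not part of the commit; the function and value names below are made up): after this change, IR like the following verifies even though the source shape is dynamic and the destination shape is static. The verifier assumes the dynamic size equals 5 at runtime.

// Hypothetical example: mixed static/dynamic shapes now pass verification.
func.func @mixed_shapes(%source: tensor<?xf32>, %dest: tensor<5xf32>) -> tensor<5xf32> {
  // The dynamic dimension of %source is assumed to equal 5 at runtime.
  %0 = bufferization.materialize_in_destination %source in %dest
      : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}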
@@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-      [AllShapesMatch<["source", "dest"]>,
-       AllElementTypesMatch<["source", "dest"]>,
+      [AllElementTypesMatch<["source", "dest"]>,
        BufferizableOpInterface, DestinationStyleOpInterface,
        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
        DeclareOpInterfaceMethods<SubsetOpInterface,
@@ -239,9 +238,9 @@ def Bufferization_MaterializeInDestinationOp
     memref, `source` materializes in `dest`, which is already a buffer. The op
     has no results in that case.
 
-    `source`, `dest` and `result` (if present) must have the same shape and
-    element type. If the op has a result, the types of `result` and `dest` must
-    match exactly (e.g., including any tensor encodings).
+    `source`, `dest` and `result` (if present) must have the same runtime shape
+    and element type. If the op has a result, the types of `result` and `dest`
+    must match exactly (e.g., including any tensor encodings).
 
     By default, this op bufferizes to a memcpy from the future buffer of the
     `source` tensor to the future buffer of the `dest` tensor or to the `dest`
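A minimal sketch of the two forms described in the documentation above (illustrative only; function and value names are assumptions): with a tensor destination the result type must match the destination type exactly, while with a memref destination the op has no result and carries the writable keyword.

// Hypothetical sketch of both destination forms.
func.func @doc_forms(%src: tensor<?xf32>, %t_dest: tensor<5xf32>, %m_dest: memref<5xf32>) -> tensor<5xf32> {
  // Tensor destination: the result type matches the destination type exactly.
  %r = bufferization.materialize_in_destination %src in %t_dest
      : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
  // Memref destination: no result; 'writable' (and here 'restrict') is specified.
  bufferization.materialize_in_destination %src in restrict writable %m_dest
      : (tensor<?xf32>, memref<5xf32>) -> ()
  return %r : tensor<5xf32>
}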
@@ -686,6 +686,24 @@ LogicalResult MaterializeInDestinationOp::verify() {
   if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
     return emitOpError("'writable' must be specified if and only if the "
                        "destination is of memref type");
+  TensorType srcType = getSource().getType();
+  ShapedType destType = cast<ShapedType>(getDest().getType());
+  if (srcType.hasRank() != destType.hasRank())
+    return emitOpError("source/destination shapes are incompatible");
+  if (srcType.hasRank()) {
+    if (srcType.getRank() != destType.getRank())
+      return emitOpError("rank mismatch between source and destination shape");
+    for (auto [src, dest] :
+         llvm::zip(srcType.getShape(), destType.getShape())) {
+      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
+        // Cannot verify dynamic dimension size. Assume that they match at
+        // runtime.
+        continue;
+      }
+      if (src != dest)
+        return emitOpError("source/destination shapes are incompatible");
+    }
+  }
   return success();
 }
 
@@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below{{source/destination shapes are incompatible}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below{{rank mismatch between source and destination shape}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
 }
 
 // -----
@@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
 }
 
 // CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
-    -> tensor<?xf32> {
+func.func @test_materialize_in_destination_op(
+    %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>,
+    %arg4: tensor<5xf32>) -> tensor<?xf32> {
   // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
   bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+  %2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<?xf32>
 }