diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 4f609ddff9a4..1c70a4b8df92 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [ def Bufferization_MaterializeInDestinationOp : Bufferization_Op<"materialize_in_destination", - [AllShapesMatch<["source", "dest"]>, - AllElementTypesMatch<["source", "dest"]>, + [AllElementTypesMatch<["source", "dest"]>, BufferizableOpInterface, DestinationStyleOpInterface, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods(getDest().getType())) return emitOpError("'writable' must be specified if and only if the " "destination is of memref type"); + TensorType srcType = getSource().getType(); + ShapedType destType = cast(getDest().getType()); + if (srcType.hasRank() != destType.hasRank()) + return emitOpError("source/destination shapes are incompatible"); + if (srcType.hasRank()) { + if (srcType.getRank() != destType.getRank()) + return emitOpError("rank mismatch between source and destination shape"); + for (auto [src, dest] : + llvm::zip(srcType.getShape(), destType.getShape())) { + if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) { + // Cannot verify dynamic dimension size. Assume that that they match at + // runtime. + continue; + } + if (src != dest) + return emitOpError("source/destination shapes are incompatible"); + } + } return success(); } diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 4ebdb0a8f049..2c8807b66de7 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() { // ----- -func.func @invalid_materialize_in_destination(%arg0: tensor, %arg1: tensor<5xf32>) { - // expected-error @below{{failed to verify that all of {source, dest} have same shape}} - bufferization.materialize_in_destination %arg0 in %arg1 : (tensor, tensor<5xf32>) -> tensor<5xf32> +func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) { + // expected-error @below{{source/destination shapes are incompatible}} + bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32> +} + +// ----- + +func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) { + // expected-error @below{{rank mismatch between source and destination shape}} + bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32> } // ----- diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir index d4bda0632189..ad4a66c1b797 100644 --- a/mlir/test/Dialect/Bufferization/ops.mlir +++ b/mlir/test/Dialect/Bufferization/ops.mlir @@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) { } // CHECK-LABEL: func @test_materialize_in_destination_op -func.func @test_materialize_in_destination_op(%arg0: tensor, %arg1: tensor, %arg2: memref) - -> tensor { +func.func @test_materialize_in_destination_op( + %arg0: tensor, %arg1: tensor, %arg2: memref, + %arg4: tensor<5xf32>) -> tensor { // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor, tensor) -> tensor %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor, tensor) -> tensor // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor, memref) -> () bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor, memref) -> () + // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor, tensor<5xf32>) -> tensor<5xf32> + %2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor, tensor<5xf32>) -> tensor<5xf32> return %1 : tensor }