[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#97213)
This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated. When converting a block signature, a `BlockTypeConversionRewrite` object and potentially multiple `ReplaceBlockArgRewrite` are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in `BlockTypeConversionRewrite::commit` and some were replaced in `ReplaceBlockArgRewrite::commit`. The new `BlockTypeConversionRewrite::commit` implementation is much simpler and no longer modifies any IR; that is done only in `ReplaceBlockArgRewrite` now. The `ConvertedArgInfo` data structure is no longer needed. To that end, materializations of dropped arguments are now built in `applySignatureConversion` instead of `materializeLiveConversions`; the latter function no longer has to deal with dropped arguments. Other minor improvements: - Add more comments to `applySignatureConversion`. Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in `legalizeUnresolvedMaterialization` instead of `legalizeConvertedArgumentTypes`. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion. This is a re-upload of #96207.
This commit is contained in:
parent
0d26f65414
commit
bbd4af5da2
@ -432,34 +432,14 @@ private:
|
||||
Block *insertBeforeBlock;
|
||||
};
|
||||
|
||||
/// This structure contains the information pertaining to an argument that has
|
||||
/// been converted.
|
||||
struct ConvertedArgInfo {
|
||||
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
|
||||
Value castValue = nullptr)
|
||||
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
|
||||
|
||||
/// The start index of in the new argument list that contains arguments that
|
||||
/// replace the original.
|
||||
unsigned newArgIdx;
|
||||
|
||||
/// The number of arguments that replaced the original argument.
|
||||
unsigned newArgSize;
|
||||
|
||||
/// The cast value that was created to cast from the new arguments to the
|
||||
/// old. This only used if 'newArgSize' > 1.
|
||||
Value castValue;
|
||||
};
|
||||
|
||||
/// Block type conversion. This rewrite is partially reflected in the IR.
|
||||
class BlockTypeConversionRewrite : public BlockRewrite {
|
||||
public:
|
||||
BlockTypeConversionRewrite(
|
||||
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
|
||||
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
|
||||
const TypeConverter *converter)
|
||||
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
|
||||
Block *block, Block *origBlock,
|
||||
const TypeConverter *converter)
|
||||
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
|
||||
origBlock(origBlock), argInfo(argInfo), converter(converter) {}
|
||||
origBlock(origBlock), converter(converter) {}
|
||||
|
||||
static bool classof(const IRRewrite *rewrite) {
|
||||
return rewrite->getKind() == Kind::BlockTypeConversion;
|
||||
@ -479,10 +459,6 @@ private:
|
||||
/// The original block that was requested to have its signature converted.
|
||||
Block *origBlock;
|
||||
|
||||
/// The conversion information for each of the arguments. The information is
|
||||
/// std::nullopt if the argument was dropped during conversion.
|
||||
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
|
||||
|
||||
/// The type converter used to convert the arguments.
|
||||
const TypeConverter *converter;
|
||||
};
|
||||
@ -691,12 +667,16 @@ public:
|
||||
/// The type of materialization.
|
||||
enum MaterializationKind {
|
||||
/// This materialization materializes a conversion for an illegal block
|
||||
/// argument type, to a legal one.
|
||||
/// argument type, to the original one.
|
||||
Argument,
|
||||
|
||||
/// This materialization materializes a conversion from an illegal type to a
|
||||
/// legal one.
|
||||
Target
|
||||
Target,
|
||||
|
||||
/// This materialization materializes a conversion from a legal type back to
|
||||
/// an illegal one.
|
||||
Source
|
||||
};
|
||||
|
||||
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
|
||||
@ -736,7 +716,7 @@ public:
|
||||
private:
|
||||
/// The corresponding type converter to use when resolving this
|
||||
/// materialization, and the kind of this materialization.
|
||||
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
|
||||
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
|
||||
converterAndKind;
|
||||
};
|
||||
} // namespace
|
||||
@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
ValueRange inputs, Type outputType,
|
||||
const TypeConverter *converter);
|
||||
|
||||
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
|
||||
ValueRange inputs,
|
||||
Type outputType,
|
||||
const TypeConverter *converter);
|
||||
|
||||
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
|
||||
Type outputType,
|
||||
const TypeConverter *converter);
|
||||
@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
|
||||
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
|
||||
for (Operation *op : block->getUsers())
|
||||
listener->notifyOperationModified(op);
|
||||
|
||||
// Process the remapping for each of the original arguments.
|
||||
for (auto [origArg, info] :
|
||||
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
|
||||
// Handle the case of a 1->0 value mapping.
|
||||
if (!info) {
|
||||
if (Value newArg =
|
||||
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
|
||||
rewriter.replaceAllUsesWith(origArg, newArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise this is a 1->1+ value mapping.
|
||||
Value castValue = info->castValue;
|
||||
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
|
||||
|
||||
// If the argument is still used, replace it with the generated cast.
|
||||
if (!origArg.use_empty()) {
|
||||
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
|
||||
castValue, origArg.getType()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BlockTypeConversionRewrite::rollback() {
|
||||
@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
|
||||
continue;
|
||||
|
||||
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
|
||||
bool isDroppedArg = replacementValue == origArg;
|
||||
if (!isDroppedArg)
|
||||
builder.setInsertionPointAfterValue(replacementValue);
|
||||
assert(replacementValue && "replacement value not found");
|
||||
Value newArg;
|
||||
if (converter) {
|
||||
builder.setInsertionPointAfterValue(replacementValue);
|
||||
newArg = converter->materializeSourceConversion(
|
||||
builder, origArg.getLoc(), origArg.getType(),
|
||||
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
|
||||
builder, origArg.getLoc(), origArg.getType(), replacementValue);
|
||||
assert((!newArg || newArg.getType() == origArg.getType()) &&
|
||||
"materialization hook did not provide a value of the expected "
|
||||
"type");
|
||||
@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
|
||||
<< "failed to materialize conversion for block argument #"
|
||||
<< it.index() << " that remained live after conversion, type was "
|
||||
<< origArg.getType();
|
||||
if (!isDroppedArg)
|
||||
diag << ", with target type " << replacementValue.getType();
|
||||
diag.attachNote(liveUser->getLoc())
|
||||
<< "see existing live user here: " << *liveUser;
|
||||
return failure();
|
||||
@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
|
||||
// Replace all uses of the old block with the new block.
|
||||
block->replaceAllUsesWith(newBlock);
|
||||
|
||||
// Remap each of the original arguments as determined by the signature
|
||||
// conversion.
|
||||
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
|
||||
argInfo.resize(origArgCount);
|
||||
|
||||
for (unsigned i = 0; i != origArgCount; ++i) {
|
||||
auto inputMap = signatureConversion.getInputMapping(i);
|
||||
if (!inputMap)
|
||||
continue;
|
||||
BlockArgument origArg = block->getArgument(i);
|
||||
Type origArgType = origArg.getType();
|
||||
|
||||
// If inputMap->replacementValue is not nullptr, then the argument is
|
||||
// dropped and a replacement value is provided to be the remappedValue.
|
||||
if (inputMap->replacementValue) {
|
||||
assert(inputMap->size == 0 &&
|
||||
"invalid to provide a replacement value when the argument isn't "
|
||||
"dropped");
|
||||
mapping.map(origArg, inputMap->replacementValue);
|
||||
std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
|
||||
signatureConversion.getInputMapping(i);
|
||||
if (!inputMap) {
|
||||
// This block argument was dropped and no replacement value was provided.
|
||||
// Materialize a replacement value "out of thin air".
|
||||
Value repl = buildUnresolvedMaterialization(
|
||||
MaterializationKind::Source, newBlock, newBlock->begin(),
|
||||
origArg.getLoc(), /*inputs=*/ValueRange(),
|
||||
/*outputType=*/origArgType, converter);
|
||||
mapping.map(origArg, repl);
|
||||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, this is a 1->1+ mapping.
|
||||
if (Value repl = inputMap->replacementValue) {
|
||||
// This block argument was dropped and a replacement value was provided.
|
||||
assert(inputMap->size == 0 &&
|
||||
"invalid to provide a replacement value when the argument isn't "
|
||||
"dropped");
|
||||
mapping.map(origArg, repl);
|
||||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
|
||||
// dialect conversion. Therefore, we need an argument materialization to
|
||||
// turn the replacement block arguments into a single SSA value that can be
|
||||
// used as a replacement.
|
||||
auto replArgs =
|
||||
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
|
||||
Value newArg;
|
||||
Value argMat = buildUnresolvedMaterialization(
|
||||
MaterializationKind::Argument, newBlock, newBlock->begin(),
|
||||
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
|
||||
mapping.map(origArg, argMat);
|
||||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
|
||||
|
||||
// If this is a 1->1 mapping and the types of new and replacement arguments
|
||||
// match (i.e. it's an identity map), then the argument is mapped to its
|
||||
// original type.
|
||||
// FIXME: We simply pass through the replacement argument if there wasn't a
|
||||
// converter, which isn't great as it allows implicit type conversions to
|
||||
// appear. We should properly restructure this code to handle cases where a
|
||||
// converter isn't provided and also to properly handle the case where an
|
||||
// argument materialization is actually a temporary source materialization
|
||||
// (e.g. in the case of 1->N).
|
||||
if (replArgs.size() == 1 &&
|
||||
(!converter || replArgs[0].getType() == origArg.getType())) {
|
||||
newArg = replArgs.front();
|
||||
mapping.map(origArg, newArg);
|
||||
} else {
|
||||
// Build argument materialization: new block arguments -> old block
|
||||
// argument type.
|
||||
Value argMat = buildUnresolvedArgumentMaterialization(
|
||||
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
|
||||
mapping.map(origArg, argMat);
|
||||
|
||||
// Build target materialization: old block argument type -> legal type.
|
||||
// Note: This function returns an "empty" type if no valid conversion to
|
||||
// a legal type exists. In that case, we continue the conversion with the
|
||||
// original block argument type.
|
||||
Type legalOutputType = converter->convertType(origArg.getType());
|
||||
if (legalOutputType && legalOutputType != origArg.getType()) {
|
||||
newArg = buildUnresolvedTargetMaterialization(
|
||||
origArg.getLoc(), argMat, legalOutputType, converter);
|
||||
mapping.map(argMat, newArg);
|
||||
} else {
|
||||
newArg = argMat;
|
||||
}
|
||||
Type legalOutputType;
|
||||
if (converter)
|
||||
legalOutputType = converter->convertType(origArgType);
|
||||
if (legalOutputType && legalOutputType != origArgType) {
|
||||
Value targetMat = buildUnresolvedTargetMaterialization(
|
||||
origArg.getLoc(), argMat, legalOutputType, converter);
|
||||
mapping.map(argMat, targetMat);
|
||||
}
|
||||
|
||||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
|
||||
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
|
||||
}
|
||||
|
||||
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
|
||||
converter);
|
||||
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
|
||||
|
||||
// Erase the old block. (It is just unlinked for now and will be erased during
|
||||
// cleanup.)
|
||||
@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
|
||||
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
|
||||
return convertOp.getResult(0);
|
||||
}
|
||||
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
|
||||
Block *block, Location loc, ValueRange inputs, Type outputType,
|
||||
const TypeConverter *converter) {
|
||||
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
|
||||
block->begin(), loc, inputs, outputType,
|
||||
converter);
|
||||
}
|
||||
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
|
||||
Location loc, Value input, Type outputType,
|
||||
const TypeConverter *converter) {
|
||||
@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
|
||||
newMaterialization = converter->materializeTargetConversion(
|
||||
rewriter, op->getLoc(), outputType, inputOperands);
|
||||
break;
|
||||
case MaterializationKind::Source:
|
||||
newMaterialization = converter->materializeSourceConversion(
|
||||
rewriter, op->getLoc(), outputType, inputOperands);
|
||||
break;
|
||||
}
|
||||
if (newMaterialization) {
|
||||
assert(newMaterialization.getType() == outputType &&
|
||||
@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
|
||||
|
||||
InFlightDiagnostic diag = op->emitError()
|
||||
<< "failed to legalize unresolved materialization "
|
||||
"from "
|
||||
<< inputOperands.getTypes() << " to " << outputType
|
||||
"from ("
|
||||
<< inputOperands.getTypes() << ") to " << outputType
|
||||
<< " that remained live after conversion";
|
||||
if (Operation *liveUser = findLiveUser(op->getUsers())) {
|
||||
diag.attachNote(liveUser->getLoc())
|
||||
|
@ -2,9 +2,8 @@
|
||||
|
||||
|
||||
func.func @test_invalid_arg_materialization(
|
||||
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
|
||||
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
|
||||
%arg0: i16) {
|
||||
// expected-note@below {{see existing live user here}}
|
||||
"foo.return"(%arg0) : (i16) -> ()
|
||||
}
|
||||
|
||||
@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
|
||||
// Make sure argument type changes aren't implicitly forwarded.
|
||||
func.func @test_signature_conversion_no_converter() {
|
||||
"test.signature_conversion_no_converter"() ({
|
||||
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
|
||||
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
|
||||
^bb0(%arg0: f32):
|
||||
// expected-note@below {{see existing live user here}}
|
||||
"test.type_consumer"(%arg0) : (f32) -> ()
|
||||
"test.return"(%arg0) : (f32) -> ()
|
||||
}) : () -> ()
|
||||
|
Loading…
x
Reference in New Issue
Block a user