llvm-project/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
2024-06-13 20:20:27 +01:00

159 lines
5.0 KiB
C++

//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
// automatically. It is used by NVVM to LLVM pass.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
//===----------------------------------------------------------------------===//
// BasicPtxBuilderInterface
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
using namespace mlir;
using namespace NVVM;
static constexpr int64_t kSharedMemorySpace = 3;
static char getRegisterType(Type type) {
if (type.isInteger(1))
return 'b';
if (type.isInteger(16))
return 'h';
if (type.isInteger(32))
return 'r';
if (type.isInteger(64))
return 'l';
if (type.isF32())
return 'f';
if (type.isF64())
return 'd';
if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
// Shared address spaces is addressed with 32-bit pointers.
if (ptr.getAddressSpace() == kSharedMemorySpace) {
return 'r';
}
return 'l';
}
// register type for struct is not supported.
llvm_unreachable("The register type could not deduced from MLIR type");
return '?';
}
static char getRegisterType(Value v) {
if (v.getDefiningOp<LLVM::ConstantOp>())
return 'n';
return getRegisterType(v.getType());
}
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
auto getModifier = [&]() -> const char * {
if (itype == PTXRegisterMod::ReadWrite) {
assert(false && "Read-Write modifier is not supported. Try setting the "
"same value as Write and Read separately.");
return "+";
}
if (itype == PTXRegisterMod::Write) {
return "=";
}
return "";
};
auto addValue = [&](Value v) {
if (itype == PTXRegisterMod::Read) {
ptxOperands.push_back(v);
return;
}
if (itype == PTXRegisterMod::ReadWrite)
ptxOperands.push_back(v);
hasResult = true;
};
llvm::raw_string_ostream ss(registerConstraints);
// Handle Structs
if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
if (itype == PTXRegisterMod::Write) {
addValue(v);
}
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
if (itype != PTXRegisterMod::Write) {
Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
interfaceOp->getLoc(), v, idx);
addValue(extractValue);
}
if (itype == PTXRegisterMod::ReadWrite) {
ss << idx << ",";
} else {
ss << getModifier() << getRegisterType(t) << ",";
}
ss.flush();
}
return;
}
// Handle Scalars
addValue(v);
ss << getModifier() << getRegisterType(v) << ",";
ss.flush();
}
LLVM::InlineAsmOp PtxBuilder::build() {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
auto resultTypes = interfaceOp->getResultTypes();
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
registerConstraints[registerConstraints.size() - 1] == ',')
registerConstraints.pop_back();
std::string ptxInstruction = interfaceOp.getPtx();
// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
interfaceOp.getPredicate().value()) {
std::string predicateStr = "@%";
predicateStr += std::to_string((ptxOperands.size() - 1));
ptxInstruction = predicateStr + " " + ptxInstruction;
}
// Tablegen doesn't accept $, so we use %, but inline assembly uses $.
// Replace all % with $
std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');
return rewriter.create<LLVM::InlineAsmOp>(
interfaceOp->getLoc(),
/*result types=*/resultTypes,
/*operands=*/ptxOperands,
/*asm_string=*/llvm::StringRef(ptxInstruction),
/*constraints=*/registerConstraints.data(),
/*has_side_effects=*/interfaceOp.hasSideEffect(),
/*is_align_stack=*/false,
/*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
}
void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
rewriter.replaceOp(interfaceOp, inlineAsmOp);
} else {
rewriter.eraseOp(interfaceOp);
}
}