159 lines
5.0 KiB
C++
159 lines
5.0 KiB
C++
//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
|
|
// automatically. It is used by the NVVM to LLVM pass.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
// Debug category name for this file; it is also pasted into the DBGS() prefix.
#define DEBUG_TYPE "ptx-builder"
|
|
// Debug stream with a "[" DEBUG_TYPE "]: " prefix already written to it.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
|
|
// Writes a newline to the debug stream.
#define DBGSNL() (llvm::dbgs() << "\n")
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BasicPtxBuilderInterface
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace NVVM;
|
|
|
|
/// Address space number that getRegisterType() treats as shared memory:
/// pointers in this space are given the 'r' register constraint instead of 'l'.
static constexpr int64_t kSharedMemorySpace = 3;
|
|
|
|
/// Maps an MLIR type to the single-character inline-assembly register
/// constraint used for it.
///
/// Integers of width 1/16/32/64 map to 'b'/'h'/'r'/'l', f32 maps to 'f' and
/// f64 maps to 'd'. LLVM pointers map to 'r' when they are in the shared
/// memory address space (kSharedMemorySpace) and to 'l' otherwise. Any other
/// type, including structs, is unsupported and reaches llvm_unreachable.
static char getRegisterType(Type type) {
  // Pointers: the shared address space is addressed with 32-bit pointers, so
  // it takes the same letter as a 32-bit integer; every other space takes the
  // 64-bit letter.
  if (auto pointerType = dyn_cast<LLVM::LLVMPointerType>(type))
    return pointerType.getAddressSpace() == kSharedMemorySpace ? 'r' : 'l';

  // Floating-point scalars.
  if (type.isF32())
    return 'f';
  if (type.isF64())
    return 'd';

  // Integer scalars, looked up by bit width.
  struct IntegerConstraint {
    unsigned bitWidth;
    char letter;
  };
  static constexpr IntegerConstraint kIntegerConstraints[] = {
      {1, 'b'}, {16, 'h'}, {32, 'r'}, {64, 'l'}};
  for (const IntegerConstraint &entry : kIntegerConstraints)
    if (type.isInteger(entry.bitWidth))
      return entry.letter;

  // Register type for struct is not supported.
  llvm_unreachable("The register type could not deduced from MLIR type");
  return '?';
}
|
|
|
|
/// Returns the register constraint letter for a value: 'n' when the value is
/// produced by an LLVM::ConstantOp, otherwise the letter derived from the
/// value's type.
static char getRegisterType(Value v) {
  const bool definedByConstant =
      static_cast<bool>(v.getDefiningOp<LLVM::ConstantOp>());
  return definedByConstant ? 'n' : getRegisterType(v.getType());
}
|
|
|
|
/// Registers `v` with the PTX inline assembly being built and appends the
/// matching register constraint(s), each followed by a comma, to
/// `registerConstraints`.
///
/// \param v     The value to register. Scalars produce one constraint; values
///              of LLVM struct type produce one constraint per struct element.
/// \param itype How the value is accessed:
///              - Read: the value is appended to `ptxOperands`.
///              - Write: `hasResult` is set and the constraint is prefixed
///                with "=".
///              - ReadWrite: the value is appended to `ptxOperands` and
///                `hasResult` is set. This is only handled for struct values,
///                where each element's constraint is its element index; for
///                scalars it trips an assertion.
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
  // Print the modifier's numeric value. Streaming `&itype` would only print
  // the address of this by-value parameter, which carries no information.
  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << static_cast<int>(itype)
                    << "\n");

  // Constraint prefix for the current modifier: "=" for Write, "" for Read.
  auto getModifier = [&]() -> const char * {
    if (itype == PTXRegisterMod::ReadWrite) {
      assert(false && "Read-Write modifier is not supported. Try setting the "
                      "same value as Write and Read separately.");
      return "+";
    }
    if (itype == PTXRegisterMod::Write) {
      return "=";
    }
    return "";
  };

  // Records one value: anything that is read becomes an operand, anything
  // that is written marks the assembly as producing a result.
  auto addValue = [&](Value value) {
    if (itype == PTXRegisterMod::Read) {
      ptxOperands.push_back(value);
      return;
    }
    if (itype == PTXRegisterMod::ReadWrite)
      ptxOperands.push_back(value);
    hasResult = true;
  };

  llvm::raw_string_ostream ss(registerConstraints);

  // Handle Structs
  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
    // A written struct is recorded once as a whole.
    if (itype == PTXRegisterMod::Write) {
      addValue(v);
    }
    for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
      // A struct that is read is unpacked, and every element is recorded
      // individually.
      if (itype != PTXRegisterMod::Write) {
        Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
            interfaceOp->getLoc(), v, idx);
        addValue(extractValue);
      }
      if (itype == PTXRegisterMod::ReadWrite) {
        // Read-write elements emit their element index as the constraint.
        ss << idx << ",";
      } else {
        ss << getModifier() << getRegisterType(t) << ",";
      }
      ss.flush();
    }
    return;
  }

  // Handle Scalars
  addValue(v);
  ss << getModifier() << getRegisterType(v) << ",";
  ss.flush();
}
|
|
|
|
/// Creates the LLVM::InlineAsmOp for the values registered so far.
///
/// The trailing comma is removed from `registerConstraints`. The assembly text
/// comes from the interface op's getPtx(); when the op carries a predicate it
/// is prefixed with "@%<N> ", where N is the index of the last entry in
/// `ptxOperands`. Finally every '%' in the text is rewritten to '$'.
///
/// \returns The newly created inline assembly op. The interface op itself is
///          not modified here.
LLVM::InlineAsmOp PtxBuilder::build() {
  // Remove the last comma from the constraints string.
  if (!registerConstraints.empty() && registerConstraints.back() == ',')
    registerConstraints.pop_back();

  std::string asmText = interfaceOp.getPtx();

  // Add the predicate to the asm string.
  const auto predicate = interfaceOp.getPredicate();
  const bool isPredicated = predicate.has_value() && predicate.value();
  if (isPredicated) {
    const std::string predicateIndex =
        std::to_string((ptxOperands.size() - 1));
    asmText = "@%" + predicateIndex + " " + asmText;
  }

  // Tablegen doesn't accept $, so the PTX text uses %, but inline assembly
  // uses $. Rewrite every % as $.
  for (char &ch : asmText)
    if (ch == '%')
      ch = '$';

  auto asmResultTypes = interfaceOp->getResultTypes();
  auto dialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                               LLVM::AsmDialect::AD_ATT);

  return rewriter.create<LLVM::InlineAsmOp>(
      interfaceOp->getLoc(),
      /*result types=*/asmResultTypes,
      /*operands=*/ptxOperands,
      /*asm_string=*/llvm::StringRef(asmText),
      /*constraints=*/registerConstraints.data(),
      /*has_side_effects=*/interfaceOp.hasSideEffect(),
      /*is_align_stack=*/false,
      /*asm_dialect=*/dialectAttr,
      /*operand_attrs=*/ArrayAttr());
}
|
|
|
|
void PtxBuilder::buildAndReplaceOp() {
|
|
LLVM::InlineAsmOp inlineAsmOp = build();
|
|
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
|
|
if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
|
|
rewriter.replaceOp(interfaceOp, inlineAsmOp);
|
|
} else {
|
|
rewriter.eraseOp(interfaceOp);
|
|
}
|
|
}
|