[mlir][gpu] Add py binding for AsyncTokenType (#96466)
The PR adds a Python binding for `AsyncTokenType` and switches the NVDSL examples from `ir.Type.parse("!gpu.async.token")` to the new `gpu.AsyncTokenType.get()`.
This commit is contained in:
parent b0bc2f6912
commit f8ff909471
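
A minimal usage sketch of the new binding (not part of this commit). It assumes an MLIR Python package built with the upstream dialects registered, so that `mlir.dialects.gpu` is importable and the `gpu` dialect can be loaded into the context:

```python
from mlir import ir
from mlir.dialects import gpu

with ir.Context() as ctx:
    # Ensure the gpu dialect is loaded so the type can be constructed.
    ctx.load_all_available_dialects()
    # `get()` defaults its ctx argument to the context active in this block.
    token_ty = gpu.AsyncTokenType.get()
    print(token_ty)  # !gpu.async.token
    # Equivalent to the string-parsing form the NVDSL examples used before.
    assert token_ty == ir.Type.parse("!gpu.async.token")
```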
@@ -19,6 +19,14 @@ extern "C" {
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu);
 
+//===--------------------------------------------------------------------===//
+// AsyncTokenType
+//===--------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAGPUAsyncTokenType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx);
+
 //===----------------------------------------------------------------------===//
 // ObjectAttr
 //===----------------------------------------------------------------------===//
@@ -25,6 +25,20 @@ using namespace mlir::python::adaptors;
 
 PYBIND11_MODULE(_mlirDialectsGPU, m) {
   m.doc() = "MLIR GPU Dialect";
+  //===------------------------------------------------------------------===//
+  // AsyncTokenType
+  //===------------------------------------------------------------------===//
+
+  auto mlirGPUAsyncTokenType =
+      mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+
+  mlirGPUAsyncTokenType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirGPUAsyncTokenTypeGet(ctx));
+      },
+      "Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
+      py::arg("ctx") = py::none());
 
   //===------------------------------------------------------------------===//
   // ObjectAttr
@@ -15,6 +15,18 @@ using namespace mlir;
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect)
 
+//===--------------------------------------------------------------------===//
+// AsyncTokenType
+//===--------------------------------------------------------------------===//
+
+bool mlirTypeIsAGPUAsyncTokenType(MlirType type) {
+  return isa<gpu::AsyncTokenType>(unwrap(type));
+}
+
+MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx) {
+  return wrap(gpu::AsyncTokenType::get(unwrap(ctx)));
+}
+
 //===----------------------------------------------------------------------===//
 // ObjectAttr
 //===----------------------------------------------------------------------===//
@@ -23,7 +23,7 @@ import numpy as np
 @NVDSL.mlir_func
 def saxpy(x, y, alpha):
     # 1. Use MLIR GPU dialect to allocate and copy memory
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
     y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
@@ -27,7 +27,7 @@ import numpy as np
 
 @NVDSL.mlir_func
 def saxpy(x, y, alpha):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
     y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
@@ -59,7 +59,7 @@ def tma_load(
 
 @NVDSL.mlir_func
 def gemm_128_128_64(a, b, d):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
@@ -258,7 +258,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
 # d -> memref<MxNxf32>
 @NVDSL.mlir_func
 def gemm_multistage(a, b, d, num_stages):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
@@ -252,7 +252,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
 
 @NVDSL.mlir_func
 def gemm_warp_specialized(a, b, d, num_stages):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
@@ -182,7 +182,7 @@ def generate_matmul_ws(
     assert K % BLOCK_K == 0
 
     module = ir.Module.create()
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     a_elem_ty = get_mlir_ty(input_type)
     b_elem_ty = get_mlir_ty(input_type)
     c_elem_ty = get_mlir_ty(output_type)
@@ -682,7 +682,7 @@ def generate_matmul_multistage(
     assert K % BLOCK_K == 0
 
     module = ir.Module.create()
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     a_elem_ty = get_mlir_ty(input_type)
     b_elem_ty = get_mlir_ty(input_type)
     c_elem_ty = get_mlir_ty(output_type)