[mlir][gpu] Add py binding for AsyncTokenType (#96466)

This PR adds a Python binding for `AsyncTokenType` in the MLIR GPU dialect.
This commit is contained in:
Guray Ozen 2024-06-24 11:39:22 +02:00 committed by GitHub
parent b0bc2f6912
commit f8ff909471
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 41 additions and 7 deletions

View File

@ -19,6 +19,14 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu);
//===-------------------------------------------------------------------===//
// AsyncTokenType
//===-------------------------------------------------------------------===//
/// Returns true if the given type is the GPU dialect's async token type
/// (`!gpu.async.token`).
MLIR_CAPI_EXPORTED bool mlirTypeIsAGPUAsyncTokenType(MlirType type);
/// Creates (uniques) the GPU async token type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
// ObjectAttr
//===---------------------------------------------------------------------===//

View File

@ -25,6 +25,20 @@ using namespace mlir::python::adaptors;
PYBIND11_MODULE(_mlirDialectsGPU, m) {
m.doc() = "MLIR GPU Dialect";
//===-------------------------------------------------------------------===//
// AsyncTokenType
//===-------------------------------------------------------------------===//
// Expose the GPU dialect's `!gpu.async.token` type to Python as
// `gpu.AsyncTokenType`, with the isa-check wired to the C API predicate.
auto mlirGPUAsyncTokenType =
mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
// Classmethod `AsyncTokenType.get(ctx=None)`: builds the type via the C API
// and wraps it in the Python subclass by calling `cls(...)`.
// NOTE(review): `ctx` defaults to py::none() — presumably the MlirContext
// caster resolves None to the context active in the caller; confirm against
// the bindings' type-caster behavior.
mlirGPUAsyncTokenType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirGPUAsyncTokenTypeGet(ctx));
},
"Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
py::arg("ctx") = py::none());
//===-------------------------------------------------------------------===//
// ObjectAttr

View File

@ -15,6 +15,18 @@ using namespace mlir;
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect)
//===-------------------------------------------------------------------===//
// AsyncTokenType
//===-------------------------------------------------------------------===//
// C API predicate: returns true iff the unwrapped MLIR type is the GPU
// dialect's gpu::AsyncTokenType (`!gpu.async.token`).
bool mlirTypeIsAGPUAsyncTokenType(MlirType type) {
return isa<gpu::AsyncTokenType>(unwrap(type));
}
// C API constructor: uniques gpu::AsyncTokenType in the given context and
// returns it wrapped as an opaque MlirType handle.
MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx) {
return wrap(gpu::AsyncTokenType::get(unwrap(ctx)));
}
//===---------------------------------------------------------------------===//
// ObjectAttr
//===---------------------------------------------------------------------===//

View File

@ -23,7 +23,7 @@ import numpy as np
@NVDSL.mlir_func
def saxpy(x, y, alpha):
# 1. Use MLIR GPU dialect to allocate and copy memory
token_ty = ir.Type.parse("!gpu.async.token")
token_ty = gpu.AsyncTokenType.get()
t1 = gpu.wait(token_ty, [])
x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])

View File

@ -27,7 +27,7 @@ import numpy as np
@NVDSL.mlir_func
def saxpy(x, y, alpha):
token_ty = ir.Type.parse("!gpu.async.token")
token_ty = gpu.AsyncTokenType.get()
t1 = gpu.wait(token_ty, [])
x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])

View File

@ -59,7 +59,7 @@ def tma_load(
@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
token_ty = ir.Type.parse("!gpu.async.token")
token_ty = gpu.AsyncTokenType.get()
t1 = gpu.wait(token_ty, [])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])

View File

@ -258,7 +258,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
# d -> memref<MxNxf32>
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
token_ty = ir.Type.parse("!gpu.async.token")
token_ty = gpu.AsyncTokenType.get()
t1 = gpu.wait(token_ty, [])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])

View File

@ -252,7 +252,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
@NVDSL.mlir_func
def gemm_warp_specialized(a, b, d, num_stages):
token_ty = ir.Type.parse("!gpu.async.token")
token_ty = gpu.AsyncTokenType.get()
t1 = gpu.wait(token_ty, [])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])

View File

@ -182,7 +182,7 @@ def generate_matmul_ws(
assert K % BLOCK_K == 0
module = ir.Module.create()
token_ty = ir.Type.parse("!gpu.async.token")
token_ty = gpu.AsyncTokenType.get()
a_elem_ty = get_mlir_ty(input_type)
b_elem_ty = get_mlir_ty(input_type)
c_elem_ty = get_mlir_ty(output_type)
@ -682,7 +682,7 @@ def generate_matmul_multistage(
assert K % BLOCK_K == 0
module = ir.Module.create()
token_ty = ir.Type.parse("!gpu.async.token")
token_ty = gpu.AsyncTokenType.get()
a_elem_ty = get_mlir_ty(input_type)
b_elem_ty = get_mlir_ty(input_type)
c_elem_ty = get_mlir_ty(output_type)