CASK Fusion

Description

  • Why
    • CASK can provide some fused gemm + gelu_tanh / gelu_erf kernels
    • Myelin's match-and-fuse strategy often fails to fuse gelu into gemm
    • Statically compiled CASK kernels are large
    • We also want to be able to fuse reduction operations
    • New CASK functionality allows Myelin to support epilogue fusion dynamically
  • What
    • Existing optimizations remain intact prior to kgen, giving three forms of gemm:
      • gemm only
      • gemm + bias
      • gemm + other ops
    • kgen will be the pass that performs the additional epilogue fusion
  • How
    • Create a CASK IR builder from a gemm shader
    • Add nodes to build a CASK IR shader
    • Build a GraphShader through an XMMA-JIT call and CASK NVRTC
    • Define a fusion_node_t type to replace the fused subgraph
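
As a numerical reference for what the three gemm forms compute, the sketch below applies an optional bias add and an optional activation to each gemm accumulator in plain Python. It uses the standard erf and tanh formulations of GELU. The function names (`gelu_erf`, `gelu_tanh`, `gemm_epilogue`) are illustrative only and are not CASK or Myelin APIs; a real fused kernel would apply the epilogue to accumulators on the GPU rather than in a Python loop.

```python
import math

def gelu_erf(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Widely used tanh approximation of GELU.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def gemm_epilogue(a, b, bias=None, act=None):
    """Reference gemm on lists of lists with an optional epilogue.

    bias=None, act=None  -> gemm only
    bias given           -> gemm + bias
    act given            -> gemm (+ bias) + other ops, e.g. gelu
    (Hypothetical helper for illustration; not a CASK entry point.)
    """
    m, k, n = len(a), len(b), len(b[0])
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            acc = sum(a[i][p] * b[p][j] for p in range(k))
            # Epilogue: applied to the accumulator before it is stored.
            if bias is not None:
                acc += bias[j]
            if act is not None:
                acc = act(acc)
            row.append(acc)
        out.append(row)
    return out
```

Calling `gemm_epilogue(a, b)`, `gemm_epilogue(a, b, bias=bias)`, and `gemm_epilogue(a, b, bias=bias, act=gelu_tanh)` corresponds to the gemm only, gemm + bias, and gemm + other ops forms listed above.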

Example

//
// Returns a CASK IR shader for the gemm and all nodes reachable from tgt
// through fusible arrows, or a null shader if this fusion is not
// supported.
//
Shader *construct(fdag_arrow_t *arrow) {
    src, tgt = arrow

    // Collect all nodes reachable from tgt through fusible arrows.
    //
    // inputs are the source nodes of arrows (A, B) where A is not in
    // nodes and B is in nodes.
    //
    // Similarly, outputs are the target nodes of arrows (A, B) where A
    // is in nodes and B is not in nodes.
    nodes, outputs, inputs = all_fused_nodes(tgt)

    // Track the CASK tensors produced by the nodes.
    shader_data = {}

    // Codegen starts from the outputs. visit() recursively traverses
    // producers, and the recursion is bounded by the input nodes.
    for node in outputs
        visit(node, shader_data, inputs)

    // Construct the final CASK shader from the input and output tensors.
    shader = makeShader(shader_data)
    return shader;
}
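
The traversal in the pseudocode can be tried out on a toy graph. The following is a minimal Python sketch under several assumptions: the graph is a plain list of `(source, target)` arrows, fusibility is a set of arrows, and "CASK tensors" are stood in for by expression strings. The extra `arrows`, `fusible`, and `producers` parameters are hypothetical additions for the sketch, and none of this reflects the real CASK IR builder API.

```python
from collections import deque

def all_fused_nodes(tgt, arrows, fusible):
    """Return (nodes, outputs, inputs) for the fusion region rooted at tgt."""
    nodes = {tgt}
    work = deque([tgt])
    while work:
        cur = work.popleft()
        for a, b in arrows:
            # Follow only fusible arrows leaving the current node.
            if a == cur and (a, b) in fusible and b not in nodes:
                nodes.add(b)
                work.append(b)
    # inputs: sources of arrows entering the region from outside.
    inputs = list(dict.fromkeys(
        a for a, b in arrows if a not in nodes and b in nodes))
    # outputs: targets of arrows leaving the region.
    outputs = list(dict.fromkeys(
        b for a, b in arrows if a in nodes and b not in nodes))
    return nodes, outputs, inputs

def visit(node, shader_data, inputs, producers):
    """Record a stand-in tensor for node, visiting its producers first."""
    if node in shader_data:
        return shader_data[node]
    if node in inputs:
        # Recursion stops here: input nodes become shader inputs.
        shader_data[node] = f"in:{node}"
    else:
        args = [visit(p, shader_data, inputs, producers)
                for p in producers.get(node, [])]
        shader_data[node] = f"{node}({', '.join(args)})"
    return shader_data[node]

def construct(arrow, arrows, fusible):
    src, tgt = arrow
    nodes, outputs, inputs = all_fused_nodes(tgt, arrows, fusible)
    producers = {}
    for a, b in arrows:
        producers.setdefault(b, []).append(a)
    shader_data = {}
    for node in outputs:
        visit(node, shader_data, inputs, producers)
    # Stand-in for makeShader: one expression per output node.
    return {out: shader_data[out] for out in outputs}

# Toy graph: gemm and bias feed bias_add, which feeds gelu, then store.
ARROWS = [("gemm", "bias_add"), ("bias", "bias_add"),
          ("bias_add", "gelu"), ("gelu", "store")]
FUSIBLE = {("bias_add", "gelu")}
result = construct(("gemm", "bias_add"), ARROWS, FUSIBLE)
print(result)
```

In this toy graph the fused region is `{bias_add, gelu}`, with `gemm` and `bias` as inputs and `store` as the only output. The sketch assumes every producer of an output node is either inside the region or an input; a real implementation would also need to decide where the unsupported case returns a null shader.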

Properties

Related content

References