{"title":"SPLAT: A framework for optimised GPU code-generation for SParse reguLar ATtention","authors":"Ahan Gupta, Yueming Yuan, Devansh Jain, Yuhao Ge, David Aponte, Yanqi Zhou, Charith Mendis","doi":"arxiv-2407.16847","DOIUrl":null,"url":null,"abstract":"Multi-head-self-attention (MHSA) mechanisms achieve state-of-the-art (SOTA)\nperformance across natural language processing and vision tasks. However, their\nquadratic dependence on sequence lengths has bottlenecked inference speeds. To\ncircumvent this bottleneck, researchers have proposed various sparse-MHSA\nmodels, where a subset of full attention is computed. Despite their promise,\ncurrent sparse libraries and compilers do not support high-performance\nimplementations for diverse sparse-MHSA patterns due to the underlying sparse\nformats they operate on. These formats, which are typically designed for\nhigh-performance & scientific computing applications, are either curated for\nextreme amounts of random sparsity (<1% non-zero values), or specific sparsity\npatterns. However, the sparsity patterns in sparse-MHSA are moderately sparse\n(10-50% non-zero values) and varied, resulting in existing sparse-formats\ntrading off generality for performance. We bridge this gap, achieving both generality and performance, by proposing a\nnovel sparse format: affine-compressed-sparse-row (ACSR) and supporting\ncode-generation scheme, SPLAT, that generates high-performance implementations\nfor diverse sparse-MHSA patterns on GPUs. Core to our proposed format and code\ngeneration algorithm is the observation that common sparse-MHSA patterns have\nuniquely regular geometric properties. These properties, which can be analyzed\njust-in-time, expose novel optimizations and tiling strategies that SPLAT\nexploits to generate high-performance implementations for diverse patterns. To\ndemonstrate SPLAT's efficacy, we use it to generate code for various\nsparse-MHSA models, achieving geomean speedups of 2.05x and 4.05x over\nhand-written kernels written in triton and TVM respectively on A100 GPUs.\nMoreover, its interfaces are intuitive and easy to use with existing\nimplementations of MHSA in JAX.","PeriodicalId":501197,"journal":{"name":"arXiv - CS - Programming Languages","volume":null,"pages":null},"PeriodicalIF":0.0000,"publicationDate":"2024-07-23","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"","citationCount":"0","resultStr":null,"platform":"Semanticscholar","paperid":null,"PeriodicalName":"arXiv - CS - Programming Languages","FirstCategoryId":"1085","ListUrlMain":"https://doi.org/arxiv-2407.16847","RegionNum":0,"RegionCategory":null,"ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":null,"EPubDate":"","PubModel":"","JCR":"","JCRName":"","Score":null,"Total":0}
Abstract
Multi-head self-attention (MHSA) mechanisms achieve state-of-the-art (SOTA) performance across natural language processing and vision tasks. However, their quadratic dependence on sequence length bottlenecks inference speed. To circumvent this bottleneck, researchers have proposed various sparse-MHSA models in which only a subset of full attention is computed. Despite their promise, current sparse libraries and compilers do not support high-performance implementations for diverse sparse-MHSA patterns because of the underlying sparse formats they operate on. These formats, typically designed for high-performance and scientific computing applications, are curated either for extremely high random sparsity (<1% non-zero values) or for specific sparsity patterns. However, the sparsity patterns in sparse-MHSA are moderately sparse (10-50% non-zero values) and varied, so existing sparse formats trade off generality for performance.

We bridge this gap, achieving both generality and performance, by proposing a novel sparse format, affine-compressed-sparse-row (ACSR), and a supporting code-generation scheme, SPLAT, which generates high-performance implementations for diverse sparse-MHSA patterns on GPUs. Core to our proposed format and code-generation algorithm is the observation that common sparse-MHSA patterns have uniquely regular geometric properties. These properties, which can be analyzed just-in-time, expose novel optimizations and tiling strategies that SPLAT exploits to generate high-performance implementations for diverse patterns. To demonstrate SPLAT's efficacy, we use it to generate code for various sparse-MHSA models, achieving geomean speedups of 2.05x and 4.05x over hand-written kernels written in Triton and TVM, respectively, on A100 GPUs. Moreover, SPLAT's interfaces are intuitive and easy to use with existing implementations of MHSA in JAX.
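
The abstract does not show SPLAT's interface, so the sketch below is only a point of reference: masked MHSA written in plain JAX with a sliding-window (banded) pattern, illustrating both the moderate density (10-50% non-zero) and the regular geometry the abstract describes. The function names, shapes, and window size are illustrative assumptions, not part of SPLAT.

```python
# A minimal sketch, not SPLAT's actual API (the abstract does not show it):
# masked MHSA in plain JAX with a sliding-window (banded) pattern, the kind of
# moderately sparse, geometrically regular sparse-MHSA pattern the paper targets.
# A code generator such as SPLAT would replace the masked dense matmuls below
# with kernels specialized to the pattern's geometry.
import jax
import jax.numpy as jnp


def sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray:
    """Boolean [seq_len, seq_len] mask: query i attends to keys j with |i - j| <= window.

    Note the regular geometry: row i's non-zero columns span [i - window, i + window],
    an affine function of i. Regularities like this are presumably what a format such
    as ACSR exploits instead of storing per-element indices.
    """
    idx = jnp.arange(seq_len)
    return jnp.abs(idx[:, None] - idx[None, :]) <= window


def masked_mhsa(q, k, v, mask):
    """q, k, v: [heads, seq_len, head_dim]; mask: [seq_len, seq_len] boolean."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum("hqd,hkd->hqk", q, k) * scale
    scores = jnp.where(mask[None, :, :], scores, -jnp.inf)  # drop masked positions
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,hkd->hqd", weights, v)


# Example: a window of 128 on a 1024-token sequence keeps roughly 25% of the score
# matrix, squarely in the 10-50% density regime described in the abstract.
heads, seq_len, head_dim = 8, 1024, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (heads, seq_len, head_dim))
k = jax.random.normal(kk, (heads, seq_len, head_dim))
v = jax.random.normal(kv, (heads, seq_len, head_dim))
out = masked_mhsa(q, k, v, sliding_window_mask(seq_len, window=128))
print(out.shape)  # (8, 1024, 64)
```

Because this plain-JAX version still materializes the full score matrix before masking, it gains no speed from the sparsity; the paper's reported speedups come from generating kernels that compute only the non-zero region.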