FlexGEMM


FlexGEMM is a high-performance, Triton-powered GEMM backend designed for 3D sparse convolutions.

It implements Explicit, Implicit, and Masked Implicit algorithm variants, featuring optional Split-K parallelism for sparse GEMM. FlexGEMM delivers state-of-the-art performance for Submanifold Convolution and voxel-based neural networks, consistently outperforming existing solutions.
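To make the variant names concrete: the Explicit variant materializes gathered neighbor features and runs dense GEMMs, while the Implicit variants fuse the gather/scatter into the kernel itself. Below is an illustrative reference of the explicit gather-GEMM-scatter pattern in plain PyTorch. It is a sketch of the algorithm family, not FlexGEMM's implementation; the neighbor_idx layout (one gather index per flattened kernel offset, K = Ks³, with -1 where a neighbor is absent) is an assumption for illustration.

import torch

def explicit_gemm_reference(feats, weight, neighbor_idx):
    # feats: (N, Ci) active-voxel features; weight: (Co, K, Ci) with the
    # kernel offsets flattened; neighbor_idx: (K, N) gives the input row
    # feeding each output row per offset, or -1 where no neighbor exists
    # (submanifold convolution skips those positions).
    N, Ci = feats.shape
    Co = weight.shape[0]
    out = torch.zeros(N, Co, dtype=feats.dtype, device=feats.device)
    for k in range(neighbor_idx.shape[0]):
        idx = neighbor_idx[k]
        mask = idx >= 0
        gathered = feats[idx[mask]]                 # explicit gather
        out[mask] += gathered @ weight[:, k, :].T   # dense GEMM, then scatter-add
    return out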


✨ Why FlexGEMM?

  • Triton-First Architecture: Built entirely on Triton, ensuring high-performance kernel execution and cross-platform compatibility.
  • Sparse-Optimized: Specifically tailored for 3D sparse tensors, efficiently handling highly irregular sparsity patterns.
  • Blazing Fast: Consistently outperforms standard sparse convolution libraries (such as spconv, torchsparse) in training throughput.

πŸ› οΈ Installation

Prerequisites

  • PyTorch β‰₯ 2.4.0
  • Triton β‰₯ 3.2.0

[WIP] BF16 precision support is under development on this branch.
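To confirm the prerequisites are satisfied, a quick version check (both packages expose __version__):

import torch
import triton

print(torch.__version__)   # expect >= 2.4.0
print(triton.__version__)  # expect >= 3.2.0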

Install via pip

git clone https://github.com/JeffreyXiang/FlexGEMM.git
cd FlexGEMM
pip install .

The wheel is pure Python (py3-none-any): CUDA sources ship under flex_gemm/kernels/cuda/ and the pybind extension is JIT-built on first use of a native op (see flex_gemm/kernels/_cuda_jit.py). The build backend is Hatchling; PyTorch is not required at build time (only at runtime).

Editable install: pip install -e .
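As a quick post-install sanity check, the package and the op module should import cleanly (a minimal sketch; note that, per the above, imports are fast but the first call into a native op triggers the one-time JIT build):

import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
print('FlexGEMM import OK')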

πŸ’» Usage Example

Here is a minimal example demonstrating how to perform a sparse submanifold convolution using FlexGEMM:

import torch
import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
from tests.spconv_fwd import sphere_coords

# 1. Prepare Sparse Voxel Data
# Generate a sparse voxel shell
feats, coords, shape = sphere_coords(256, 256, dtype=torch.float16, device='cuda')

# 2. Define Weights and Bias
Ci, Co = 256, 256
Ks = 3
weight = torch.randn(Co, Ks, Ks, Ks, Ci, dtype=torch.float16, device='cuda', requires_grad=True)
bias = torch.randn(Co, dtype=torch.float16, device='cuda', requires_grad=True)

# 3. Configure Algorithm
# Example: Using Masked Implicit GEMM with Split-K optimization
flex_gemm.ops.spconv.set_algorithm(
    flex_gemm.ops.spconv.Algorithm.MASKED_IMPLICIT_GEMM_SPLITK
)

# 4. Forward Pass
out_feats, neighbor_cache = sparse_submanifold_conv3d(
    feats, coords, shape,
    weight, bias,
)

# 5. Backward Pass
out_feats.sum().backward()
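After the backward pass, gradients have been accumulated on the leaf tensors. A quick sanity check (submanifold convolution preserves the set of active voxels, so the output has one row per input voxel):

# Output keeps one row per active voxel; gradients live on weight and bias.
assert out_feats.shape == (feats.shape[0], Co)
assert weight.grad is not None and weight.grad.shape == weight.shape
assert bias.grad is not None and bias.grad.shape == bias.shape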

Using with torch.compile

FlexGEMM supports torch.compile via custom op wrappers. The key idea is to separate geometry preparation from computation: build the neighbor cache once from geometry (outside compile), freeze it into a SpConvConfig, then use that config inside the compiled region.

import torch
import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
from flex_gemm.ops.spconv.submanifold_conv3d import SubMConv3dFunction

# --- Phase 1: Preparation (outside torch.compile, run once) ---

feats, coords, shape = ...  # your sparse voxel data
Ci, Co, Ks = 256, 256, 3    # channel widths and kernel size (as in the example above)
weight = torch.randn(Co, Ks, Ks, Ks, Ci, device='cuda', requires_grad=True)
bias = torch.randn(Co, device='cuda', requires_grad=True)

# Build neighbor cache directly from geometry (no forward pass needed).
# Uses the default algorithm (MASKED_IMPLICIT_GEMM_SPLITK).
neighbor_cache = SubMConv3dFunction._compute_neighbor_cache(
    coords, shape, (Ks, Ks, Ks), (1, 1, 1),
)

# Freeze: pre-computes all block-size variants, returns a compile-friendly config
config = neighbor_cache.freeze()

# --- Phase 2: Compiled training loop ---

@torch.compile
def train_step(feats, weight, bias):
    # Pass config= to use the compiled path (returns output only, no cache)
    out = sparse_submanifold_conv3d(feats, weight=weight, bias=bias, config=config)
    return out.sum()

loss = train_step(feats, weight, bias)
loss.backward()

Note: The config= path is only needed for torch.compile. The legacy API (sparse_submanifold_conv3d(feats, coords, shape, weight, bias)) continues to work unchanged for eager execution.
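To measure what compilation buys on your workload, here is a minimal timing sketch using CUDA events (reusing the train_step and tensors defined above; iteration counts are arbitrary):

import torch

def time_fn(fn, *args, iters=50, warmup=10):
    # Warmup also gives torch.compile a chance to trace and compile.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

print(f'compiled step: {time_fn(train_step, feats, weight, bias):.3f} ms/iter')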

πŸ“Š Performance

FlexGEMM demonstrates significant speed improvements over existing baselines.

Test Environment:

  • GPU: NVIDIA A100 80GB PCIe
  • Software: PyTorch 2.4.1, CUDA 12.0, Triton 3.2.0

Benchmark Results

Note: FlexGEMM achieves ~2Γ— acceleration over previous state-of-the-art methods with efficient data formats such as FP16 and TF32.

[Benchmark figures: training speed at FP16, TF32, and FP32 precision]

Performance Summary

  • SOTA Speed: Consistently outperforms spconv, torchsparse, and fvdb.
  • Scalability: Robust performance across various channel widths (C=64 to C=1024) and resolutions (RES=8 to RES=1024).
  • Memory Efficient: Delivers higher throughput without increasing GPU memory overhead.
  • Application Ready: Ideal for high-resolution voxelized point clouds, submanifold convolutions, and large-scale 3D networks.

🀝 Contributing

We welcome contributions to make FlexGEMM faster and more robust!

How to help

  • Report Bugs: Open an issue describing the bug and how to reproduce it.
  • Suggest Features: Have an idea for a new algorithm or optimization? Let us know!
  • Submit Pull Requests:
    1. Fork the repository and create your branch from main.
    2. Ensure your code follows the project's style.
    3. Run the tests in the tests/ directory to ensure no regressions.
    4. Open a Pull Request with a detailed description.

We appreciate all contributors who help improve this project!

πŸ“œ License

This project is released under the MIT License.
