diff --git a/README.md b/README.md index 36ecc6d..5f2ae17 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,17 @@ python ./examples/run_music_generation.py --model_path=./ckpt --version="3B" By default this command will generate a piece of music conditioned on lyrics and tags provided in `./assets` folder. The output music will be saved at `./assets/output.mp3`. +#### Use torch.compile for faster generations: + +``` +python ./examples/run_music_generation.py --model_path=./ckpt --version="3B" --compile --compile_mode max-autotune +``` + +_Note_: The first run using torch.compile may take longer as it compiles before inference, but once compiled it will not need to re-compile on the next run. +Expected performance improvement using torch.compile is ~2x. +If you are on Windows, you may need to install triton-windows: `pip install -U "triton-windows>=3.2,<3.3"`. + + #### FAQs 1. How to specify lyrics and tags? diff --git a/examples/run_music_generation.py b/examples/run_music_generation.py index 84924e7..a30cc67 100644 --- a/examples/run_music_generation.py +++ b/examples/run_music_generation.py @@ -48,6 +48,11 @@ def parse_args(): parser.add_argument("--mula_dtype", type=str2dtype, default="bfloat16") parser.add_argument("--codec_dtype", type=str2dtype, default="float32") parser.add_argument("--lazy_load", type=str2bool, default=False) + parser.add_argument("--compile", action="store_true", + help="Enable torch.compile for faster inference") + parser.add_argument("--compile_mode", type=str, default="default", + choices=["default", "reduce-overhead", "max-autotune"], + help="torch.compile mode") return parser.parse_args() @@ -65,6 +70,8 @@ def parse_args(): }, version=args.version, lazy_load=args.lazy_load, + compile_model=args.compile, + compile_mode=args.compile_mode, ) with torch.no_grad(): pipe( diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py index c9111ff..375ac5a 100644 --- 
a/src/heartlib/pipelines/music_generation.py +++ b/src/heartlib/pipelines/music_generation.py @@ -10,6 +10,24 @@ import json from contextlib import contextmanager import gc +import warnings + + +def _get_compile_backend(requested_backend: Optional[str]) -> str: + """Determine best available backend for torch.compile.""" + if requested_backend: + return requested_backend + + # Check if triton is available (triton-windows on Windows) + try: + import triton + return "inductor" # Full optimization with triton + except ImportError: + warnings.warn( + "Triton not found. On Windows, install triton-windows for best performance: " + "pip install -U 'triton-windows>=3.2,<3.3'. Falling back to eager backend." + ) + return "eager" def _resolve_paths(pretrained_path: str, version: str): @@ -94,6 +112,9 @@ def __init__( muq_mulan: Optional[Any], text_tokenizer: Tokenizer, config: HeartMuLaGenConfig, + compile_model: bool = False, + compile_backend: Optional[str] = None, + compile_mode: str = "default", ): self.muq_mulan = muq_mulan @@ -111,6 +132,10 @@ def __init__( self.codec_path = heartcodec_path self.codec_device = heartcodec_device + self._compile_model = compile_model + self._compile_backend = compile_backend + self._compile_mode = compile_mode + self._mula: Optional[HeartMuLa] = None self._codec: Optional[HeartCodec] = None if not lazy_load: @@ -122,6 +147,7 @@ def __init__( device_map=self.mula_device, dtype=self.mula_dtype, ) + self._apply_compile(self._mula) self._codec = HeartCodec.from_pretrained( self.codec_path, device_map=self.codec_device, @@ -129,6 +155,28 @@ def __init__( ) self.lazy_load = lazy_load + def _apply_compile(self, model: HeartMuLa): + """Apply torch.compile to backbone and decoder if requested.""" + if not self._compile_model: + return + try: + backend = _get_compile_backend(self._compile_backend) + model.backbone = torch.compile( + model.backbone, + backend=backend, + mode=self._compile_mode, + dynamic=True, + ) + model.decoder = torch.compile( + 
model.decoder, + backend=backend, + mode=self._compile_mode, + dynamic=True, + ) + print(f"Backbone and decoder compiled with backend={backend}, mode={self._compile_mode}") + except Exception as e: + warnings.warn(f"torch.compile failed ({e}), continuing without compilation") + @property def mula(self) -> HeartMuLa: if isinstance(self._mula, HeartMuLa): @@ -138,6 +186,7 @@ def mula(self) -> HeartMuLa: device_map=self.mula_device, dtype=self.mula_dtype, ) + self._apply_compile(self._mula) return self._mula @property @@ -357,6 +406,9 @@ def from_pretrained( dtype: Union[torch.dtype, Dict[str, torch.dtype]], version: str, lazy_load: bool = False, + compile_model: bool = False, + compile_backend: Optional[str] = None, + compile_mode: str = "default", ): mula_path, codec_path, tokenizer_path, gen_config_path = _resolve_paths( @@ -380,4 +432,7 @@ def from_pretrained( config=gen_config, heartmula_dtype=mula_dtype, heartcodec_dtype=codec_dtype, + compile_model=compile_model, + compile_backend=compile_backend, + compile_mode=compile_mode, )