11 changes: 11 additions & 0 deletions README.md
@@ -114,6 +114,17 @@ python ./examples/run_music_generation.py --model_path=./ckpt --version="3B"

By default this command will generate a piece of music conditioned on the lyrics and tags provided in the `./assets` folder. The output music will be saved at `./assets/output.mp3`.

#### Use torch.compile for faster generation

```
python ./examples/run_music_generation.py --model_path=./ckpt --version="3B" --compile --compile_mode max-autotune
```

_Note_: The first run with torch.compile takes longer because the model is compiled before inference; once compiled, subsequent generations reuse the compiled model.
The expected speedup from torch.compile is ~2x.
On Windows, you may need to install triton-windows first: `pip install -U "triton-windows>=3.2,<3.3"`.
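
For intuition, the flag wraps the model with `torch.compile` roughly as in this self-contained sketch (the toy `Linear` model is illustrative only; the real pipeline compiles its backbone and decoder the same way):

```python
import torch

# Toy stand-in for the model; the pipeline compiles its backbone/decoder this way.
model = torch.nn.Linear(16, 16)
compiled = torch.compile(model, backend="inductor", mode="max-autotune", dynamic=True)

x = torch.randn(4, 16)
_ = compiled(x)  # first call triggers compilation (slow, one-time cost)
_ = compiled(x)  # later calls reuse the compiled graph (fast)
```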


#### FAQs

1. How to specify lyrics and tags?
7 changes: 7 additions & 0 deletions examples/run_music_generation.py
@@ -48,6 +48,11 @@ def parse_args():
parser.add_argument("--mula_dtype", type=str2dtype, default="bfloat16")
parser.add_argument("--codec_dtype", type=str2dtype, default="float32")
parser.add_argument("--lazy_load", type=str2bool, default=False)
parser.add_argument("--compile", action="store_true",
help="Enable torch.compile for faster inference")
parser.add_argument("--compile_mode", type=str, default="default",
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode")
return parser.parse_args()


@@ -65,6 +70,8 @@ def parse_args():
},
version=args.version,
lazy_load=args.lazy_load,
compile_model=args.compile,
compile_mode=args.compile_mode,
)
with torch.no_grad():
pipe(
55 changes: 55 additions & 0 deletions src/heartlib/pipelines/music_generation.py
@@ -10,6 +10,24 @@
import json
from contextlib import contextmanager
import gc
import warnings


def _get_compile_backend(requested_backend: Optional[str]) -> str:
"""Determine best available backend for torch.compile."""
if requested_backend:
return requested_backend

# Check if triton is available (triton-windows on Windows)
try:
import triton
return "inductor" # Full optimization with triton
except ImportError:
warnings.warn(
"Triton not found. On Windows, install triton-windows for best performance: "
"pip install -U triton-windows>=3.2,<3.3'. Falling back to eager backend."
)
return "eager"


def _resolve_paths(pretrained_path: str, version: str):
@@ -94,6 +112,9 @@ def __init__(
muq_mulan: Optional[Any],
text_tokenizer: Tokenizer,
config: HeartMuLaGenConfig,
compile_model: bool = False,
compile_backend: Optional[str] = None,
compile_mode: str = "default",
):

self.muq_mulan = muq_mulan
@@ -111,6 +132,10 @@ def __init__(
self.codec_path = heartcodec_path
self.codec_device = heartcodec_device

self._compile_model = compile_model
self._compile_backend = compile_backend
self._compile_mode = compile_mode

self._mula: Optional[HeartMuLa] = None
self._codec: Optional[HeartCodec] = None
if not lazy_load:
@@ -122,13 +147,36 @@ def __init__(
device_map=self.mula_device,
dtype=self.mula_dtype,
)
self._apply_compile(self._mula)
self._codec = HeartCodec.from_pretrained(
self.codec_path,
device_map=self.codec_device,
dtype=self.codec_dtype,
)
self.lazy_load = lazy_load

def _apply_compile(self, model: HeartMuLa):
"""Apply torch.compile to backbone and decoder if requested."""
if not self._compile_model:
return
try:
backend = _get_compile_backend(self._compile_backend)
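# dynamic=True: generation changes sequence lengths between calls, so dynamic
# shapes avoid a recompile for every new length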
model.backbone = torch.compile(
model.backbone,
backend=backend,
mode=self._compile_mode,
dynamic=True,
)
model.decoder = torch.compile(
model.decoder,
backend=backend,
mode=self._compile_mode,
dynamic=True,
)
print(f"Backbone and decoder compiled with backend={backend}, mode={self._compile_mode}")
except Exception as e:
warnings.warn(f"torch.compile failed ({e}), continuing without compilation")

@property
def mula(self) -> HeartMuLa:
if isinstance(self._mula, HeartMuLa):
@@ -138,6 +186,7 @@ def mula(self) -> HeartMuLa:
device_map=self.mula_device,
dtype=self.mula_dtype,
)
self._apply_compile(self._mula)
return self._mula

@property
@@ -357,6 +406,9 @@ def from_pretrained(
dtype: Union[torch.dtype, Dict[str, torch.dtype]],
version: str,
lazy_load: bool = False,
compile_model: bool = False,
compile_backend: Optional[str] = None,
compile_mode: str = "default",
):

mula_path, codec_path, tokenizer_path, gen_config_path = _resolve_paths(
@@ -380,4 +432,7 @@
config=gen_config,
heartmula_dtype=mula_dtype,
heartcodec_dtype=codec_dtype,
compile_model=compile_model,
compile_backend=compile_backend,
compile_mode=compile_mode,
)
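
For reference, here is a hedged sketch of how the new options are used from Python. Only the keyword arguments are confirmed by this diff; the class name `MusicGenerationPipeline` and the exact import path are assumptions for illustration:

```python
import torch

# Hypothetical import: the actual pipeline class name is not shown in this diff.
from heartlib.pipelines.music_generation import MusicGenerationPipeline

pipe = MusicGenerationPipeline.from_pretrained(
    "./ckpt",                      # checkpoint directory, as in the README example
    dtype=torch.bfloat16,
    version="3B",
    compile_model=True,            # opt in to torch.compile
    compile_backend=None,          # None = auto-detect (inductor if triton imports)
    compile_mode="max-autotune",   # or "default" / "reduce-overhead"
)
```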