Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions examples/run_music_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ def parse_args():

parser.add_argument("--max_audio_length_ms", type=int, default=240_000)
parser.add_argument("--topk", type=int, default=50)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0,
help="Sampling temperature (also used as start value for dynamic scheduling)")
parser.add_argument("--temperature_end", type=float, default=None,
help="Ending temperature for dynamic scheduling")
parser.add_argument("--temperature_schedule", type=str, default="linear",
choices=["linear", "cosine"],
help="Temperature interpolation schedule")
parser.add_argument("--cfg_scale", type=float, default=1.5)
parser.add_argument("--mula_device", type=str2device, default="cuda")
parser.add_argument("--codec_device", type=str2device, default="cuda")
Expand All @@ -66,6 +72,17 @@ def parse_args():
version=args.version,
lazy_load=args.lazy_load,
)

if args.temperature_end is not None:
temperature = {
"start": args.temperature,
"end": args.temperature_end,
"schedule": args.temperature_schedule,
}
print(f"Using dynamic temperature: {args.temperature} -> {args.temperature_end} ({args.temperature_schedule})")
else:
temperature = args.temperature

with torch.no_grad():
pipe(
{
Expand All @@ -75,7 +92,7 @@ def parse_args():
max_audio_length_ms=args.max_audio_length_ms,
save_path=args.save_path,
topk=args.topk,
temperature=args.temperature,
temperature=temperature,
cfg_scale=args.cfg_scale,
)
print(f"Generated music saved to {args.save_path}")
20 changes: 16 additions & 4 deletions src/heartlib/pipelines/music_generation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tokenizers import Tokenizer
from ..heartmula.modeling_heartmula import HeartMuLa
from ..heartcodec.modeling_heartcodec import HeartCodec
from ..utils.temperature_schedule import parse_temperature_spec, compute_temperature
import torch
from typing import Dict, Any, Optional, Union
import os
Expand Down Expand Up @@ -182,9 +183,12 @@ def _unload(self):

def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {"cfg_scale": kwargs.get("cfg_scale", 1.5)}
temperature = kwargs.get("temperature", 1.0)
temp_config = parse_temperature_spec(temperature)

forward_kwargs = {
"max_audio_length_ms": kwargs.get("max_audio_length_ms", 120_000),
"temperature": kwargs.get("temperature", 1.0),
"temp_config": temp_config,
"topk": kwargs.get("topk", 50),
"cfg_scale": kwargs.get("cfg_scale", 1.5),
}
Expand Down Expand Up @@ -268,7 +272,7 @@ def _forward(
self,
model_inputs: Dict[str, Any],
max_audio_length_ms: int,
temperature: float,
temp_config: Dict[str, Any],
topk: int,
cfg_scale: float,
):
Expand All @@ -286,7 +290,7 @@ def _forward(
tokens=prompt_tokens,
tokens_mask=prompt_tokens_mask,
input_pos=prompt_pos,
temperature=temperature,
temperature=temp_config["start"],
topk=topk,
cfg_scale=cfg_scale,
continuous_segments=continuous_segment,
Expand Down Expand Up @@ -314,6 +318,14 @@ def _pad_audio_token(token: torch.Tensor):
max_audio_frames = max_audio_length_ms // 80

for i in tqdm(range(max_audio_frames)):
progress = i / max_audio_frames
current_temp = compute_temperature(
progress,
temp_config["start"],
temp_config["end"],
temp_config["schedule"],
)

curr_token, curr_token_mask = _pad_audio_token(curr_token)
with torch.autocast(
device_type=self.mula_device.type, dtype=self.mula_dtype
Expand All @@ -322,7 +334,7 @@ def _pad_audio_token(token: torch.Tensor):
tokens=curr_token,
tokens_mask=curr_token_mask,
input_pos=prompt_pos[..., -1:] + i + 1,
temperature=temperature,
temperature=current_temp,
topk=topk,
cfg_scale=cfg_scale,
continuous_segments=None,
Expand Down
12 changes: 12 additions & 0 deletions src/heartlib/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Utility functions for HeartLib."""
from .temperature_schedule import (
TemperatureSpec,
parse_temperature_spec,
compute_temperature,
)

__all__ = [
"TemperatureSpec",
"parse_temperature_spec",
"compute_temperature",
]
80 changes: 80 additions & 0 deletions src/heartlib/utils/temperature_schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Temperature scheduling utility

This module helps adjust temperature throughout an inference run.

The purpose is to let inference start with a higher temperature and gradually reduce it, mitigating the accumulation of precision loss over longer runs.

"""
from typing import Union, Tuple, Dict, Any
import math

TemperatureSpec = Union[float, Tuple[float, float], Dict[str, Any]]


def parse_temperature_spec(temp_spec: "TemperatureSpec") -> Dict[str, Any]:
    """
    Parse a temperature specification into a normalized config dict.

    Args:
        temp_spec: One of:
            - a number: constant temperature for the whole run
            - a (start, end) tuple or 2-element list: linear ramp
              (lists are accepted so specs can come straight from JSON/YAML)
            - a dict with keys:
                - start: starting temperature (default: 1.0)
                - end: ending temperature (default: same as start)
                - schedule: "linear" or "cosine" (default: "linear")

    Returns:
        Dict with keys: start, end, schedule, is_dynamic.

    Raises:
        ValueError: If the spec has an unsupported type/shape, or a dict
            spec names an unknown schedule (validated here so bad configs
            fail fast instead of mid-generation).
    """
    valid_schedules = ("linear", "cosine")

    if isinstance(temp_spec, (int, float)):
        value = float(temp_spec)
        return {
            "start": value,
            "end": value,
            "schedule": "linear",
            "is_dynamic": False,
        }
    elif isinstance(temp_spec, (tuple, list)) and len(temp_spec) == 2:
        start = float(temp_spec[0])
        end = float(temp_spec[1])
        return {
            "start": start,
            "end": end,
            "schedule": "linear",
            # Consistent with the dict branch: a flat ramp is not dynamic.
            "is_dynamic": start != end,
        }
    elif isinstance(temp_spec, dict):
        start = float(temp_spec.get("start", 1.0))
        end = float(temp_spec.get("end", start))
        schedule = temp_spec.get("schedule", "linear")
        if schedule not in valid_schedules:
            raise ValueError(
                f"Unknown schedule: {schedule}. Use 'linear' or 'cosine'."
            )
        return {
            "start": start,
            "end": end,
            "schedule": schedule,
            "is_dynamic": start != end,
        }
    else:
        raise ValueError(f"Invalid temperature spec: {temp_spec}")


def compute_temperature(
    progress: float,
    start: float,
    end: float,
    schedule: str = "linear",
) -> float:
    """
    Return the interpolated temperature at a given point in generation.

    Args:
        progress: Fraction of generation completed, 0.0 to 1.0. Values
            outside the range are clamped.
        start: Temperature at progress 0.0.
        end: Temperature at progress 1.0.
        schedule: Interpolation method, "linear" or "cosine".

    Returns:
        The temperature at ``progress`` under the chosen schedule.

    Raises:
        ValueError: If ``schedule`` is not recognized.
    """
    # Clamp so callers slightly outside [0, 1] still get endpoint values.
    t = min(1.0, max(0.0, progress))

    if schedule == "linear":
        # Straight-line blend between the two endpoints.
        return start + t * (end - start)
    if schedule == "cosine":
        # Half-cosine annealing: flat near the endpoints, steepest mid-run.
        weight = 0.5 * (1.0 + math.cos(math.pi * t))
        return end + weight * (start - end)
    raise ValueError(f"Unknown schedule: {schedule}. Use 'linear' or 'cosine'.")