diff --git a/examples/run_music_generation.py b/examples/run_music_generation.py index 84924e7..cb61d62 100644 --- a/examples/run_music_generation.py +++ b/examples/run_music_generation.py @@ -41,7 +41,13 @@ def parse_args(): parser.add_argument("--max_audio_length_ms", type=int, default=240_000) parser.add_argument("--topk", type=int, default=50) - parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0, + help="Sampling temperature (also used as start value for dynamic scheduling)") + parser.add_argument("--temperature_end", type=float, default=None, + help="Ending temperature for dynamic scheduling") + parser.add_argument("--temperature_schedule", type=str, default="linear", + choices=["linear", "cosine"], + help="Temperature interpolation schedule") parser.add_argument("--cfg_scale", type=float, default=1.5) parser.add_argument("--mula_device", type=str2device, default="cuda") parser.add_argument("--codec_device", type=str2device, default="cuda") @@ -66,6 +72,17 @@ def parse_args(): version=args.version, lazy_load=args.lazy_load, ) + + if args.temperature_end is not None: + temperature = { + "start": args.temperature, + "end": args.temperature_end, + "schedule": args.temperature_schedule, + } + print(f"Using dynamic temperature: {args.temperature} -> {args.temperature_end} ({args.temperature_schedule})") + else: + temperature = args.temperature + with torch.no_grad(): pipe( { @@ -75,7 +92,7 @@ def parse_args(): max_audio_length_ms=args.max_audio_length_ms, save_path=args.save_path, topk=args.topk, - temperature=args.temperature, + temperature=temperature, cfg_scale=args.cfg_scale, ) print(f"Generated music saved to {args.save_path}") diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py index c9111ff..cf3a52b 100644 --- a/src/heartlib/pipelines/music_generation.py +++ b/src/heartlib/pipelines/music_generation.py @@ -1,6 +1,7 @@ from tokenizers 
import Tokenizer from ..heartmula.modeling_heartmula import HeartMuLa from ..heartcodec.modeling_heartcodec import HeartCodec +from ..utils.temperature_schedule import parse_temperature_spec, compute_temperature import torch from typing import Dict, Any, Optional, Union import os @@ -182,9 +183,12 @@ def _unload(self): def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {"cfg_scale": kwargs.get("cfg_scale", 1.5)} + temperature = kwargs.get("temperature", 1.0) + temp_config = parse_temperature_spec(temperature) + forward_kwargs = { "max_audio_length_ms": kwargs.get("max_audio_length_ms", 120_000), - "temperature": kwargs.get("temperature", 1.0), + "temp_config": temp_config, "topk": kwargs.get("topk", 50), "cfg_scale": kwargs.get("cfg_scale", 1.5), } @@ -268,7 +272,7 @@ def _forward( self, model_inputs: Dict[str, Any], max_audio_length_ms: int, - temperature: float, + temp_config: Dict[str, Any], topk: int, cfg_scale: float, ): @@ -286,7 +290,7 @@ def _forward( tokens=prompt_tokens, tokens_mask=prompt_tokens_mask, input_pos=prompt_pos, - temperature=temperature, + temperature=temp_config["start"], topk=topk, cfg_scale=cfg_scale, continuous_segments=continuous_segment, @@ -314,6 +318,14 @@ def _pad_audio_token(token: torch.Tensor): max_audio_frames = max_audio_length_ms // 80 for i in tqdm(range(max_audio_frames)): + progress = i / max_audio_frames + current_temp = compute_temperature( + progress, + temp_config["start"], + temp_config["end"], + temp_config["schedule"], + ) + curr_token, curr_token_mask = _pad_audio_token(curr_token) with torch.autocast( device_type=self.mula_device.type, dtype=self.mula_dtype @@ -322,7 +334,7 @@ def _pad_audio_token(token: torch.Tensor): tokens=curr_token, tokens_mask=curr_token_mask, input_pos=prompt_pos[..., -1:] + i + 1, - temperature=temperature, + temperature=current_temp, topk=topk, cfg_scale=cfg_scale, continuous_segments=None, diff --git a/src/heartlib/utils/__init__.py b/src/heartlib/utils/__init__.py new file mode 
100644 index 0000000..abd8a72 --- /dev/null +++ b/src/heartlib/utils/__init__.py @@ -0,0 +1,12 @@ +"""Utility functions for HeartLib.""" +from .temperature_schedule import ( +    TemperatureSpec, +    parse_temperature_spec, +    compute_temperature, +) + +__all__ = [ +    "TemperatureSpec", +    "parse_temperature_spec", +    "compute_temperature", +] diff --git a/src/heartlib/utils/temperature_schedule.py b/src/heartlib/utils/temperature_schedule.py new file mode 100644 index 0000000..83830c8 --- /dev/null +++ b/src/heartlib/utils/temperature_schedule.py @@ -0,0 +1,80 @@ +"""Temperature scheduling utility + +This module helps adjust temperature throughout an inference run. + +The purpose of this is to allow inference to start with a higher temperature and gradually reduce it to mitigate precision loss accumulation for longer runs. + +""" +from typing import Union, Tuple, Dict, Any +import math + +TemperatureSpec = Union[float, Tuple[float, float], Dict[str, Any]] + + +def parse_temperature_spec(temp_spec: TemperatureSpec) -> Dict[str, Any]: +    """ +    Parse temperature specification into normalized config. 
+ + Args: + temp_spec: Either a float, tuple (start, end), or dict with keys: + - start: Starting temperature + - end: Ending temperature + - schedule: "linear" or "cosine" (default: "linear") + + Returns: + Dict with keys: start, end, schedule, is_dynamic + """ + if isinstance(temp_spec, (int, float)): + return { + "start": float(temp_spec), + "end": float(temp_spec), + "schedule": "linear", + "is_dynamic": False, + } + elif isinstance(temp_spec, tuple) and len(temp_spec) == 2: + return { + "start": float(temp_spec[0]), + "end": float(temp_spec[1]), + "schedule": "linear", + "is_dynamic": True, + } + elif isinstance(temp_spec, dict): + start = float(temp_spec.get("start", 1.0)) + end = float(temp_spec.get("end", start)) + return { + "start": start, + "end": end, + "schedule": temp_spec.get("schedule", "linear"), + "is_dynamic": start != end, + } + else: + raise ValueError(f"Invalid temperature spec: {temp_spec}") + + +def compute_temperature( + progress: float, + start: float, + end: float, + schedule: str = "linear", +) -> float: + """ + Compute temperature at a given progress point. + + Args: + progress: Generation progress from 0.0 to 1.0 + start: Starting temperature + end: Ending temperature + schedule: Interpolation method ("linear" or "cosine") + + Returns: + Interpolated temperature value + """ + progress = max(0.0, min(1.0, progress)) # Clamp to [0, 1] + + if schedule == "linear": + return start + (end - start) * progress + elif schedule == "cosine": + # Cosine annealing: smooth transition + return end + (start - end) * 0.5 * (1 + math.cos(math.pi * progress)) + else: + raise ValueError(f"Unknown schedule: {schedule}. Use 'linear' or 'cosine'.")