diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py
index c14081e..dbd2eed 100644
--- a/src/together/cli/api/finetune.py
+++ b/src/together/cli/api/finetune.py
@@ -17,6 +17,8 @@
     DownloadCheckpointType,
     FinetuneEventType,
     FinetuneTrainingLimits,
+    FullTrainingType,
+    LoRATrainingType,
 )
 from together.utils import (
     finetune_price_to_dollars,
@@ -29,13 +31,21 @@
 
 _CONFIRMATION_MESSAGE = (
     "You are about to create a fine-tuning job. "
-    "The cost of your job will be determined by the model size, the number of tokens "
+    "The estimated price of this job is {price}. "
+    "The actual cost of your job will be determined by the model size, the number of tokens "
     "in the training file, the number of tokens in the validation file, the number of epochs, and "
-    "the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
+    "the number of evaluations. Visit https://www.together.ai/pricing to learn more about fine-tuning pricing.\n"
+    "{warning}"
     "You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
     "Do you want to proceed?"
 )
 
+_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
+    "The estimated price of this job is significantly greater than your current credit limit and balance combined. "
+    "It will likely get cancelled due to insufficient funds. "
+    "Consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
+)
+
 
 class DownloadCheckpointTypeChoice(click.Choice):
     def __init__(self) -> None:
@@ -357,12 +367,36 @@ def create(
             "You have specified a number of evaluation loops but no validation file."
         )
 
-    if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
+    finetune_price_estimation_result = client.fine_tuning.estimate_price(
+        training_file=training_file,
+        validation_file=validation_file,
+        model=model,
+        n_epochs=n_epochs,
+        n_evals=n_evals,
+        training_type="lora" if lora else "full",
+        training_method=training_method,
+    )
+
+    price = click.style(
+        f"${finetune_price_estimation_result.estimated_total_price:.2f}",
+        bold=True,
+    )
+
+    if not finetune_price_estimation_result.allowed_to_proceed:
+        warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
+    else:
+        warning = ""
+
+    confirmation_message = _CONFIRMATION_MESSAGE.format(
+        price=price,
+        warning=warning,
+    )
+
+    if confirm or click.confirm(confirmation_message, default=True, show_default=True):
         response = client.fine_tuning.create(
             **training_args,
             verbose=True,
         )
-        report_string = f"Successfully submitted a fine-tuning job {response.id}"
 
         if response.created_at is not None:
             created_time = datetime.strptime(
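For context on the CLI change above: the command now fetches a price estimate before prompting, then interpolates the styled price and an optional warning into _CONFIRMATION_MESSAGE. A minimal sketch of that composition, assuming the two module-level messages defined in the patch are in scope; the estimate values here are hypothetical stand-ins for a real FinetunePriceEstimationResponse:

import click

# Hypothetical stand-ins for fields of a FinetunePriceEstimationResponse.
estimated_total_price = 12.34
allowed_to_proceed = False

price = click.style(f"${estimated_total_price:.2f}", bold=True)
warning = (
    click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
    if not allowed_to_proceed
    else ""
)
# Shows the estimated price up front, plus a red warning when the estimate
# exceeds the user's credit limit and balance combined.
print(_CONFIRMATION_MESSAGE.format(price=price, warning=warning))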
combined. " + "It will likely get cancelled due to insufficient funds. " + "Proceed at your own risk." +) def create_finetune_request( @@ -473,12 +481,34 @@ def create( hf_api_token=hf_api_token, hf_output_repo_name=hf_output_repo_name, ) + if from_checkpoint is None and from_hf_model is None: + price_estimation_result = self.estimate_price( + training_file=training_file, + validation_file=validation_file, + model=model_name, + n_epochs=finetune_request.n_epochs, + n_evals=finetune_request.n_evals, + training_type="lora" if lora else "full", + training_method=training_method, + ) + price_limit_passed = price_estimation_result.allowed_to_proceed + else: + # unsupported case + price_limit_passed = True if verbose: rprint( "Submitting a fine-tuning job with the following parameters:", finetune_request, ) + if not price_limit_passed: + rprint( + "[red]" + + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format( + price_estimation_result.estimated_total_price + ) + + "[/red]", + ) parameter_payload = finetune_request.model_dump(exclude_none=True) response, _, _ = requestor.request( @@ -493,6 +523,81 @@ def create( return FinetuneResponse(**response.data) + def estimate_price( + self, + *, + training_file: str, + model: str, + validation_file: str | None = None, + n_epochs: int | None = 1, + n_evals: int | None = 0, + training_type: str = "lora", + training_method: str = "sft", + ) -> FinetunePriceEstimationResponse: + """ + Estimates the price of a fine-tuning job + + Args: + training_file (str): File-ID of a file uploaded to the Together API + model (str): Name of the base model to run fine-tune job on + validation_file (str, optional): File ID of a file uploaded to the Together API for validation. + n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. + n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. + training_type (str, optional): Training type. Defaults to "lora". + training_method (str, optional): Training method. Defaults to "sft". + + Returns: + FinetunePriceEstimationResponse: Object containing the price estimation result. 
+ """ + training_type_cls: TrainingType + training_method_cls: TrainingMethod + + if training_method == "sft": + training_method_cls = TrainingMethodSFT(method="sft") + elif training_method == "dpo": + training_method_cls = TrainingMethodDPO(method="dpo") + else: + raise ValueError(f"Unknown training method: {training_method}") + + if training_type.lower() == "lora": + # parameters of lora are unused in price estimation + # but we need to set them to valid values + training_type_cls = LoRATrainingType( + type="Lora", + lora_r=16, + lora_alpha=16, + lora_dropout=0.0, + lora_trainable_modules="all-linear", + ) + elif training_type.lower() == "full": + training_type_cls = FullTrainingType(type="Full") + else: + raise ValueError(f"Unknown training type: {training_type}") + + request = FinetunePriceEstimationRequest( + training_file=training_file, + validation_file=validation_file, + model=model, + n_epochs=n_epochs, + n_evals=n_evals, + training_type=training_type_cls, + training_method=training_method_cls, + ) + parameter_payload = request.model_dump(exclude_none=True) + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = requestor.request( + options=TogetherRequest( + method="POST", url="fine-tunes/estimate-price", params=parameter_payload + ), + stream=False, + ) + assert isinstance(response, TogetherResponse) + + return FinetunePriceEstimationResponse(**response.data) + def list(self) -> FinetuneList: """ Lists fine-tune job history @@ -941,11 +1046,34 @@ async def create( hf_output_repo_name=hf_output_repo_name, ) + if from_checkpoint is None and from_hf_model is None: + price_estimation_result = await self.estimate_price( + training_file=training_file, + validation_file=validation_file, + model=model_name, + n_epochs=finetune_request.n_epochs, + n_evals=finetune_request.n_evals, + training_type="lora" if lora else "full", + training_method=training_method, + ) + price_limit_passed = price_estimation_result.allowed_to_proceed + else: + # unsupported case + price_limit_passed = True + if verbose: rprint( "Submitting a fine-tuning job with the following parameters:", finetune_request, ) + if not price_limit_passed: + rprint( + "[red]" + + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format( + price_estimation_result.estimated_total_price + ) + + "[/red]", + ) parameter_payload = finetune_request.model_dump(exclude_none=True) response, _, _ = await requestor.arequest( @@ -961,6 +1089,81 @@ async def create( return FinetuneResponse(**response.data) + async def estimate_price( + self, + *, + training_file: str, + model: str, + validation_file: str | None = None, + n_epochs: int | None = 1, + n_evals: int | None = 0, + training_type: str = "lora", + training_method: str = "sft", + ) -> FinetunePriceEstimationResponse: + """ + Estimates the price of a fine-tuning job + + Args: + training_file (str): File-ID of a file uploaded to the Together API + model (str): Name of the base model to run fine-tune job on + validation_file (str, optional): File ID of a file uploaded to the Together API for validation. + n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. + n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. + training_type (str, optional): Training type. Defaults to "lora". + training_method (str, optional): Training method. Defaults to "sft". + + Returns: + FinetunePriceEstimationResponse: Object containing the price estimation result. 
+ """ + training_type_cls: TrainingType + training_method_cls: TrainingMethod + + if training_method == "sft": + training_method_cls = TrainingMethodSFT(method="sft") + elif training_method == "dpo": + training_method_cls = TrainingMethodDPO(method="dpo") + else: + raise ValueError(f"Unknown training method: {training_method}") + + if training_type.lower() == "lora": + # parameters of lora are unused in price estimation + # but we need to set them to valid values + training_type_cls = LoRATrainingType( + type="Lora", + lora_r=16, + lora_alpha=16, + lora_dropout=0.0, + lora_trainable_modules="all-linear", + ) + elif training_type.lower() == "full": + training_type_cls = FullTrainingType(type="Full") + else: + raise ValueError(f"Unknown training type: {training_type}") + + request = FinetunePriceEstimationRequest( + training_file=training_file, + validation_file=validation_file, + model=model, + n_epochs=n_epochs, + n_evals=n_evals, + training_type=training_type_cls, + training_method=training_method_cls, + ) + parameter_payload = request.model_dump(exclude_none=True) + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="POST", url="fine-tunes/estimate-price", params=parameter_payload + ), + stream=False, + ) + assert isinstance(response, TogetherResponse) + + return FinetunePriceEstimationResponse(**response.data) + async def list(self) -> FinetuneList: """ Async method to list fine-tune job history diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py index f4dd737..61c054a 100644 --- a/src/together/types/__init__.py +++ b/src/together/types/__init__.py @@ -54,6 +54,8 @@ FinetuneListEvents, FinetuneRequest, FinetuneResponse, + FinetunePriceEstimationRequest, + FinetunePriceEstimationResponse, FinetuneDeleteResponse, FinetuneTrainingLimits, FullTrainingType, @@ -103,6 +105,8 @@ "FinetuneDeleteResponse", "FinetuneDownloadResult", "FinetuneLRScheduler", + "FinetunePriceEstimationRequest", + "FinetunePriceEstimationResponse", "LinearLRScheduler", "LinearLRSchedulerArgs", "CosineLRScheduler", diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py index 52c802b..286932e 100644 --- a/src/together/types/finetune.py +++ b/src/together/types/finetune.py @@ -308,6 +308,32 @@ def validate_training_type(cls, v: TrainingType) -> TrainingType: raise ValueError("Unknown training type") +class FinetunePriceEstimationRequest(BaseModel): + """ + Fine-tune price estimation request type + """ + + training_file: str + validation_file: str | None = None + model: str + n_epochs: int + n_evals: int + training_type: TrainingType + training_method: TrainingMethod + + +class FinetunePriceEstimationResponse(BaseModel): + """ + Fine-tune price estimation response type + """ + + estimated_total_price: float + user_limit: float + estimated_train_token_count: int + estimated_eval_token_count: int + allowed_to_proceed: bool + + class FinetuneList(BaseModel): # object type object: Literal["list"] | None = None diff --git a/tests/unit/test_finetune_resources.py b/tests/unit/test_finetune_resources.py index b72e5b1..6020a0c 100644 --- a/tests/unit/test_finetune_resources.py +++ b/tests/unit/test_finetune_resources.py @@ -1,6 +1,10 @@ import pytest +from unittest.mock import MagicMock, Mock, patch +from together.client import Together from together.resources.finetune import create_finetune_request +from together.together_response import TogetherResponse +from together.types import 
diff --git a/tests/unit/test_finetune_resources.py b/tests/unit/test_finetune_resources.py
index b72e5b1..6020a0c 100644
--- a/tests/unit/test_finetune_resources.py
+++ b/tests/unit/test_finetune_resources.py
@@ -1,6 +1,10 @@
 import pytest
+from unittest.mock import MagicMock, Mock, patch
 
+from together.client import Together
 from together.resources.finetune import create_finetune_request
+from together.together_response import TogetherResponse
+from together.types import TogetherRequest
 from together.types.finetune import (
     FinetuneFullTrainingLimits,
     FinetuneLoraTrainingLimits,
@@ -12,6 +16,7 @@
 _TRAINING_FILE = "file-7dbce5e9-7993-4520-9f3e-a7ece6c39d84"
 _VALIDATION_FILE = "file-7dbce5e9-7553-4520-9f3e-a7ece6c39d84"
 _FROM_CHECKPOINT = "ft-12345678-1234-1234-1234-1234567890ab"
+_DUMMY_ID = "ft-12345678-1234-1234-1234-1234567890ab"
 _MODEL_LIMITS = FinetuneTrainingLimits(
     max_num_epochs=20,
     max_learning_rate=1.0,
@@ -31,6 +36,43 @@
 )
 
 
+def mock_request(options: TogetherRequest, *args, **kwargs):
+    if options.url == "fine-tunes/estimate-price":
+        return (
+            TogetherResponse(
+                data={
+                    "estimated_total_price": 100,
+                    "allowed_to_proceed": True,
+                    "estimated_train_token_count": 1000,
+                    "estimated_eval_token_count": 100,
+                    "user_limit": 1000,
+                },
+                headers={},
+            ),
+            None,
+            None,
+        )
+    elif options.url == "fine-tunes":
+        return (
+            TogetherResponse(
+                data={
+                    "id": _DUMMY_ID,
+                },
+                headers={},
+            ),
+            None,
+            None,
+        )
+    elif options.url == "fine-tunes/models/limits":
+        return (
+            TogetherResponse(data=_MODEL_LIMITS.model_dump(), headers={}),
+            None,
+            None,
+        )
+    else:
+        raise ValueError(f"Unknown URL: {options.url}")
+
+
 def test_simple_request():
     request = create_finetune_request(
         model_limits=_MODEL_LIMITS,
@@ -335,3 +377,78 @@ def test_train_on_inputs_not_supported_for_dpo():
             training_method="dpo",
             train_on_inputs=True,
         )
+
+
+def test_price_estimation_request(mocker):
+    mock_requestor = Mock()
+    mock_requestor.request = MagicMock()
+    mock_requestor.request.side_effect = mock_request
+    mocker.patch(
+        "together.abstract.api_requestor.APIRequestor", return_value=mock_requestor
+    )
+    test_data = [
+        {
+            "training_type": "lora",
+            "training_method": "sft",
+        },
+        {
+            "training_type": "lora",
+            "training_method": "dpo",
+        },
+        {
+            "training_type": "full",
+            "training_method": "sft",
+        },
+    ]
+    client = Together(api_key="fake_api_key")
+    for test_case in test_data:
+        response = client.fine_tuning.estimate_price(
+            training_file=_TRAINING_FILE,
+            model=_MODEL_NAME,
+            validation_file=_VALIDATION_FILE,
+            n_epochs=1,
+            n_evals=0,
+            training_type=test_case["training_type"],
+            training_method=test_case["training_method"],
+        )
+        assert response.estimated_total_price > 0
+        assert response.allowed_to_proceed
+        assert response.estimated_train_token_count > 0
+        assert response.estimated_eval_token_count > 0
+
+
+def test_create_ft_job(mocker):
+    mock_requestor = Mock()
+    mock_requestor.request = MagicMock()
+    mock_requestor.request.side_effect = mock_request
+    mocker.patch(
+        "together.abstract.api_requestor.APIRequestor", return_value=mock_requestor
+    )
+
+    client = Together(api_key="fake_api_key")
+    response = client.fine_tuning.create(
+        training_file=_TRAINING_FILE,
+        model=_MODEL_NAME,
+        validation_file=_VALIDATION_FILE,
+        n_epochs=1,
+        n_evals=0,
+        lora=True,
+        training_method="sft",
+    )
+
+    assert mock_requestor.request.call_count == 3
+    assert response.id == _DUMMY_ID
+
+    response = client.fine_tuning.create(
+        training_file=_TRAINING_FILE,
+        model=None,
+        validation_file=_VALIDATION_FILE,
+        n_epochs=1,
+        n_evals=0,
+        lora=True,
+        training_method="sft",
+        from_checkpoint=_FROM_CHECKPOINT,
+    )
+
+    assert mock_requestor.request.call_count == 5
+    assert response.id == _DUMMY_ID
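The tests above only exercise the synchronous client; the call-count assertions (3, then 5) confirm that create() issues the extra estimate-price request for a fresh job but skips it when from_checkpoint is set. A hedged sketch of an async counterpart, assuming pytest-asyncio is installed and that AsyncTogether routes through the same patched APIRequestor via arequest:

import pytest
from unittest.mock import AsyncMock, Mock

from together.client import AsyncTogether


@pytest.mark.asyncio
async def test_price_estimation_request_async(mocker):
    mock_requestor = Mock()
    # Reuse the synchronous mock_request handler; a plain-function
    # side_effect on AsyncMock becomes the awaited result.
    mock_requestor.arequest = AsyncMock(side_effect=mock_request)
    mocker.patch(
        "together.abstract.api_requestor.APIRequestor", return_value=mock_requestor
    )

    client = AsyncTogether(api_key="fake_api_key")
    response = await client.fine_tuning.estimate_price(
        training_file=_TRAINING_FILE,
        model=_MODEL_NAME,
        n_epochs=1,
        n_evals=0,
        training_type="lora",
        training_method="sft",
    )
    assert response.allowed_to_proceed
    assert response.estimated_total_price > 0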