Skip to content
20 changes: 8 additions & 12 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from pathlib import Path

from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.cli_common import apologize_and_exit, get_git_repo_or_none, parse_config_file_or_exit
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.extension import install_vscode_extension
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version

Expand Down Expand Up @@ -163,10 +162,7 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:


def process_pyproject_config(args: Namespace) -> Namespace:
try:
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
except ValueError as e:
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
pyproject_config, pyproject_file_path = parse_config_file_or_exit(args.config_file)
supported_keys = [
"module_root",
"tests_root",
Expand Down Expand Up @@ -248,21 +244,21 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
no_pr = getattr(args, "no_pr", False)

if not no_pr:
import git

from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import require_github_app_or_exit

# Ensure that the user can actually open PRs on the repo.
try:
git_repo = git.Repo(search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
maybe_git_repo = get_git_repo_or_none()
if maybe_git_repo is None:
mode = "--all" if hasattr(args, "all") else "--file"
logger.exception(
logger.error(
f"I couldn't find a git repository in the current directory. "
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
# After None check and apologize_and_exit(), we know git_repo is not None
git_repo = maybe_git_repo
assert git_repo is not None # For mypy
git_remote = getattr(args, "git_remote", None)
if not check_and_push_branch(git_repo, git_remote=git_remote):
exit_with_message("Branch is not pushed...", error_on_exit=True)
Expand Down
119 changes: 43 additions & 76 deletions codeflash/cli_cmds/cli_common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import shutil
import sys
from typing import Callable, cast

import click
import inquirer
from typing import TYPE_CHECKING, Any, Optional

from codeflash.cli_cmds.console import console, logger

if TYPE_CHECKING:
from pathlib import Path

from git import Repo


def apologize_and_exit() -> None:
console.rule()
Expand All @@ -20,78 +21,44 @@ def apologize_and_exit() -> None:
sys.exit(1)


def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwargs: str | bool) -> str | bool:
new_args = []
new_kwargs = {}

if len(args) == 1:
message = str(args[0])
else:
message = str(kwargs["message"])
new_kwargs = kwargs.copy()
split_messages = split_string_to_cli_width(message, is_confirm=func == inquirer.confirm)
for split_message in split_messages[:-1]:
click.echo(split_message)

last_message = split_messages[-1]

if len(args) == 1:
new_args.append(last_message)
else:
new_kwargs["message"] = last_message

return func(*new_args, **new_kwargs)


def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]: # noqa: FBT001, FBT002
cli_width, _ = shutil.get_terminal_size()
# split string to lines that accommodate "[?] " prefix
cli_width -= len("[?] ")
lines = split_string_to_fit_width(string, cli_width)
def get_git_repo_or_none(search_path: Optional[Path] = None) -> Optional[Repo]:
"""Get git repository or None if not in a git repo."""
import git

# split last line to additionally accommodate ": " or " (y/N): " suffix
cli_width -= len(" (y/N):") if is_confirm else len(": ")
last_lines = split_string_to_fit_width(lines[-1], cli_width)
try:
if search_path:
return git.Repo(search_path, search_parent_directories=True)
return git.Repo(search_parent_directories=True)
except git.InvalidGitRepositoryError:
return None

lines = lines[:-1] + last_lines

if len(lines) > 1:
for i in range(len(lines[:-1])):
# Add yellow color to question mark in "[?] " prefix
lines[i] = "[\033[33m?\033[0m] " + lines[i]
return lines


def inquirer_wrapper_path(*args: str, **kwargs: str) -> dict[str, str] | None:
new_args = []
message = kwargs["message"]
new_kwargs = kwargs.copy()
split_messages = split_string_to_cli_width(message)
for split_message in split_messages[:-1]:
click.echo(split_message)

last_message = split_messages[-1]
new_kwargs["message"] = last_message
new_args.append(args[0])

return cast("dict[str, str]", inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)]))


def split_string_to_fit_width(string: str, width: int) -> list[str]:
words = string.split()
lines = []
current_line = [words[0]]
current_length = len(words[0])

for word in words[1:]:
word_length = len(word)
if current_length + word_length + 1 <= width:
current_line.append(word)
current_length += word_length + 1
def require_git_repo_or_exit(search_path: Optional[Path] = None, error_message: Optional[str] = None) -> Repo:
"""Get git repository or exit with error."""
repo = get_git_repo_or_none(search_path)
if repo is None:
if error_message:
logger.error(error_message)
else:
lines.append(" ".join(current_line))
current_line = [word]
current_length = word_length

lines.append(" ".join(current_line))
return lines
logger.error(
"I couldn't find a git repository in the current directory. "
"A git repository is required for this operation."
)
apologize_and_exit()
# After checking for None and calling apologize_and_exit(), we know repo is not None
# but mypy doesn't understand apologize_and_exit() never returns, so we assert
assert repo is not None
return repo


def parse_config_file_or_exit(config_file: Optional[Path] = None, **kwargs: Any) -> tuple[dict[str, Any], Path]: # noqa: ANN401
"""Parse config file or exit with error."""
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file

try:
return parse_config_file(config_file, **kwargs)
except ValueError as e:
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
# exit_with_message never returns when error_on_exit=True, but mypy doesn't know that
raise # pragma: no cover
Loading
Loading