Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 58 additions & 37 deletions upgrade_tool/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import subprocess
import sys
from typing import List, Optional
from typing import List, Optional, Tuple

import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn

# Import the concurrent futures module for threading
from concurrent.futures import ThreadPoolExecutor, as_completed

# Import the refactored utility functions
from .utils import get_outdated_packages, generate_packages_table

Expand All @@ -19,6 +22,29 @@
add_completion=False,
)

def upgrade_package(pkg: dict) -> Tuple[str, str, bool]:
"""
Worker function to upgrade a single package in a separate thread.

Args:
pkg: A dictionary containing package information ('name', 'latest_version').

Returns:
A tuple containing (package_name, latest_version, success_boolean).
"""
pkg_name = pkg['name']
latest_version = pkg['latest_version']
try:
# Execute the pip upgrade command, suppressing output
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "--upgrade", pkg_name],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return pkg_name, latest_version, True
except subprocess.CalledProcessError:
return pkg_name, latest_version, False

@app.command()
def upgrade(
packages_to_upgrade: Optional[List[str]] = typer.Argument(
Expand All @@ -33,46 +59,43 @@ def upgrade(
dry_run: bool = typer.Option(
False, "--dry-run", help="Simulate the upgrade without making any changes."
),
workers: int = typer.Option(
10, "--workers", "-w", help="Number of concurrent workers for parallel upgrades."
)
):
"""
Checks for and upgrades outdated Python packages.
Checks for and concurrently upgrades outdated Python packages.
"""
# Use the utility function to get outdated packages
# --- Filtering Logic (Unchanged) ---
outdated_packages = get_outdated_packages()

if not outdated_packages:
console.print("[bold green]✨ All packages are up to date! ✨[/bold green]")
raise typer.Exit()

# --- Filtering Logic ---
if packages_to_upgrade:
# User specified which packages to upgrade
name_to_pkg = {pkg['name'].lower(): pkg for pkg in outdated_packages}
target_packages = [name_to_pkg[name.lower()] for name in packages_to_upgrade if name.lower() in name_to_pkg]
else:
# Default to all outdated packages
target_packages = outdated_packages

if exclude:
# Exclude packages specified by the user (case-insensitive)
exclude_set = {name.lower() for name in exclude}
target_packages = [pkg for pkg in target_packages if pkg['name'].lower() not in exclude_set]

if not target_packages:
console.print("[bold yellow]No packages match the specified criteria for upgrade.[/bold yellow]")
raise typer.Exit()

# --- Display and Confirmation ---
# Use the utility function to generate the table
# --- Display and Confirmation (Unchanged) ---
table = generate_packages_table(target_packages, title="Outdated Python Packages")
console.print(table)

if dry_run:
console.print("\n[bold yellow]--dry-run enabled. No packages will be upgraded.[/bold yellow]")
console.print(f"\n[bold yellow]--dry-run enabled. Would upgrade {len(target_packages)} packages with {workers} workers.[/bold yellow]")
raise typer.Exit()

if not yes:
# Use Typer's confirmation prompt
try:
confirmed = typer.confirm("\nProceed with the upgrade?")
if not confirmed:
Expand All @@ -82,50 +105,48 @@ def upgrade(
console.print("\nUpgrade cancelled by user.")
raise typer.Exit()

# --- Execution Logic ---
console.print("\n[bold blue]Starting upgrade process...[/bold blue]")
# --- Concurrent Execution Logic (The New Engine) ---
console.print(f"\n[bold blue]Starting parallel upgrade with {workers} workers...[/bold blue]")

# Define a rich Progress bar
progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
console=console,
console=console
)

success_count = 0
fail_count = 0

with progress:
upgrade_task = progress.add_task("[green]Upgrading...", total=len(target_packages))
success_count = 0
fail_count = 0

for pkg in target_packages:
pkg_name = pkg['name']
progress.update(upgrade_task, description=f"Upgrading [bold cyan]{pkg_name}[/bold cyan]...")
# Create a thread pool with the specified number of workers
with ThreadPoolExecutor(max_workers=workers) as executor:
# Submit an upgrade task for each package
future_to_pkg = {executor.submit(upgrade_package, pkg): pkg for pkg in target_packages}

try:
# Execute the pip upgrade command
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "--upgrade", pkg_name],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
progress.console.print(f" ✅ [green]Successfully upgraded {pkg_name} to {pkg['latest_version']}[/green]")
success_count += 1
except subprocess.CalledProcessError:
progress.console.print(f" ❌ [red]Failed to upgrade {pkg_name}[/red]")
fail_count += 1

progress.advance(upgrade_task)
# Process results as they complete
for future in as_completed(future_to_pkg):
pkg_name, latest_version, success = future.result()

if success:
progress.console.print(f" ✅ [green]Successfully upgraded {pkg_name} to {latest_version}[/green]")
success_count += 1
else:
progress.console.print(f" ❌ [red]Failed to upgrade {pkg_name}[/red]")
fail_count += 1

# Advance the progress bar for each completed task
progress.advance(upgrade_task)

# --- Summary Report ---
# --- Summary Report (Unchanged) ---
console.print("\n--- [bold]Upgrade Complete[/bold] ---")
console.print(f"[green]Successfully upgraded:[/green] {success_count} packages")
if fail_count > 0:
console.print(f"[red]Failed to upgrade:[/red] {fail_count} packages")
console.print("--------------------------")


# This makes the script runnable directly, though it's meant to be installed via the entry point
if __name__ == "__main__":
app()