diff --git a/cmd/apply.go b/cmd/apply.go index 3b14c5b..6476106 100644 --- a/cmd/apply.go +++ b/cmd/apply.go @@ -132,7 +132,10 @@ func runApply(stackName, composeFile string, opts *ApplyOptions) error { stackDeployer := swarm.NewStackDeployer(cli, stackName, 3) // Create snapshot before deployment - snap := snapshot.CreateSnapshot(ctx, stackDeployer) + snap, err := snapshot.CreateSnapshot(ctx, stackDeployer) + if err != nil { + return fmt.Errorf("deployment blocked: %w", err) + } // Track deployment state deploymentComplete := make(chan bool, 1) @@ -148,7 +151,10 @@ func runApply(stackName, composeFile string, opts *ApplyOptions) error { os.Exit(0) default: log.Println("Deployment interrupted, initiating rollback...") - snapshot.Rollback(context.Background(), stackDeployer, snap) + // Create context with timeout for rollback to prevent hanging + rollbackCtx, rollbackCancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer rollbackCancel() + snapshot.Rollback(rollbackCtx, stackDeployer, snap) os.Exit(130) } }() @@ -348,12 +354,12 @@ func monitorServiceTasks(ctx context.Context, cli *client.Client, svc swarm.Serv } // waitForAllTasksHealthy waits for all tasks of updated services to become healthy +// Optimized to use batch API calls instead of per-service calls func waitForAllTasksHealthy(ctx context.Context, cli *client.Client, stackName string, updatedServices []swarm.ServiceUpdateResult, deployID string) error { ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() startTime := time.Now() - serviceHealthyCount := make(map[string]int) for { select { @@ -364,62 +370,74 @@ func waitForAllTasksHealthy(ctx context.Context, cli *client.Client, stackName s case <-ticker.C: allHealthy := true unhealthyTasks := []string{} + serviceHealthyCount := make(map[string]int) + + // OPTIMIZATION 1: Batch fetch all services in stack with one API call + serviceFilter := filters.NewArgs() + serviceFilter.Add("label", "com.docker.stack.namespace="+stackName) + allServices, err := cli.ServiceList(ctx, types.ServiceListOptions{ + Filters: serviceFilter, + }) + if err != nil { + log.Printf("[HealthCheck] Failed to list services: %v", err) + allHealthy = false + continue + } - for _, svc := range updatedServices { - // Get current service by name to get updated service ID - // During updates, service ID remains the same but this ensures we have the latest service state - serviceFilter := filters.NewArgs() - serviceFilter.Add("name", svc.ServiceName) - services, err := cli.ServiceList(ctx, types.ServiceListOptions{ - Filters: serviceFilter, - }) - if err != nil { - log.Printf("[HealthCheck] Failed to get service %s: %v", svc.ServiceName, err) - allHealthy = false - continue - } - if len(services) == 0 { - log.Printf("[HealthCheck] Service %s not found", svc.ServiceName) - allHealthy = false - continue - } + // Create service name -> service map for quick lookup + serviceMap := make(map[string]dockerswarm.Service) + for _, svc := range allServices { + serviceMap[svc.Spec.Name] = svc + } + + // OPTIMIZATION 2: Batch fetch all tasks in stack with one API call + taskFilter := filters.NewArgs() + taskFilter.Add("label", "com.docker.stack.namespace="+stackName) + allStackTasks, err := cli.TaskList(ctx, types.TaskListOptions{ + Filters: taskFilter, + }) + if err != nil { + log.Printf("[HealthCheck] Failed to list tasks: %v", err) + allHealthy = false + continue + } - currentServiceID := services[0].ID + // Group tasks by service name and filter by deployID + tasksByService := make(map[string][]dockerswarm.Task) + for _, task := range allStackTasks { + // Filter by deployID + if task.Spec.ContainerSpec != nil && task.Spec.ContainerSpec.Labels != nil { + if taskDeployID, ok := task.Spec.ContainerSpec.Labels["com.stackman.deploy.id"]; ok && taskDeployID == deployID { + serviceName := task.Spec.ContainerSpec.Labels["com.docker.swarm.service.name"] + tasksByService[serviceName] = append(tasksByService[serviceName], task) + } + } + } - // Get ALL tasks for this service - // Note: Docker API does not support label filtering for tasks, only for containers - // So we get all tasks and filter manually - filter := filters.NewArgs() - filter.Add("service", currentServiceID) + // OPTIMIZATION 3: Collect containers that need inspection + type containerTask struct { + containerID string + task dockerswarm.Task + serviceName string + } + containersToInspect := []containerTask{} - allTasks, err := cli.TaskList(ctx, types.TaskListOptions{ - Filters: filter, - }) - if err != nil { - log.Printf("[HealthCheck] Failed to list tasks for service %s: %v", svc.ServiceName, err) + // Process each updated service + for _, svc := range updatedServices { + _, exists := serviceMap[svc.ServiceName] + if !exists { + log.Printf("[HealthCheck] Service %s not found", svc.ServiceName) allHealthy = false continue } - // Filter tasks by deployID from ContainerSpec labels - tasks := []dockerswarm.Task{} - for _, t := range allTasks { - if t.Spec.ContainerSpec != nil && t.Spec.ContainerSpec.Labels != nil { - if taskDeployID, ok := t.Spec.ContainerSpec.Labels["com.stackman.deploy.id"]; ok && taskDeployID == deployID { - tasks = append(tasks, t) - } - } - } - - log.Printf("[HealthCheck] Service %s: found %d tasks with deployID %s (total tasks: %d)", - svc.ServiceName, len(tasks), deployID, len(allTasks)) + tasks := tasksByService[svc.ServiceName] + log.Printf("[HealthCheck] Service %s: found %d tasks with deployID %s", + svc.ServiceName, len(tasks), deployID) - healthyTaskCount := 0 hasRunningTask := false for _, t := range tasks { - // deployID label guarantees correct tasks - no version check needed - // Log failed/shutdown tasks but don't fail immediately (Docker Swarm may restart) if t.Status.State == dockerswarm.TaskStateFailed || t.Status.State == dockerswarm.TaskStateShutdown || @@ -444,47 +462,81 @@ func waitForAllTasksHealthy(ctx context.Context, cli *client.Client, stackName s continue } - // Check container health if healthcheck is defined + // Mark container for inspection if it exists if t.Status.ContainerStatus != nil && t.Status.ContainerStatus.ContainerID != "" { - containerInfo, err := cli.ContainerInspect(ctx, t.Status.ContainerStatus.ContainerID) - if err != nil { - log.Printf("[HealthCheck] Failed to inspect container %s for task %s (%s): %v", - t.Status.ContainerStatus.ContainerID[:12], t.ID[:12], svc.ServiceName, err) - allHealthy = false - unhealthyTasks = append(unhealthyTasks, fmt.Sprintf("%s/%s (inspect failed)", svc.ServiceName, t.ID[:12])) - continue - } - - // If container has health check, wait for healthy status - if containerInfo.State.Health != nil { - if containerInfo.State.Health.Status != container.Healthy { - allHealthy = false - unhealthyTasks = append(unhealthyTasks, fmt.Sprintf("%s/%s (health: %s)", svc.ServiceName, t.ID[:12], containerInfo.State.Health.Status)) - log.Printf("[HealthCheck] ⏳ Task %s (%s) is %s", t.ID[:12], svc.ServiceName, containerInfo.State.Health.Status) - } else { - log.Printf("[HealthCheck] ✅ Task %s (%s) is healthy", t.ID[:12], svc.ServiceName) - healthyTaskCount++ - } - } else { - // No healthcheck defined, just check if running - log.Printf("[HealthCheck] ✅ Task %s (%s) is running (no healthcheck)", t.ID[:12], svc.ServiceName) - healthyTaskCount++ - } + containersToInspect = append(containersToInspect, containerTask{ + containerID: t.Status.ContainerStatus.ContainerID, + task: t, + serviceName: svc.ServiceName, + }) } else { - // No container status yet allHealthy = false unhealthyTasks = append(unhealthyTasks, fmt.Sprintf("%s/%s (no container)", svc.ServiceName, t.ID[:12])) log.Printf("[HealthCheck] ⏳ Task %s (%s) has no container yet", t.ID[:12], svc.ServiceName) } } - // Track if service has running tasks if !hasRunningTask { log.Printf("[HealthCheck] ⏳ Service %s has no running tasks yet (may be restarting)", svc.ServiceName) allHealthy = false } + } + + // OPTIMIZATION 4: Parallel container inspections with goroutines + type inspectResult struct { + ct containerTask + info types.ContainerJSON + err error + } + + resultChan := make(chan inspectResult, len(containersToInspect)) + var wg sync.WaitGroup + + for _, ct := range containersToInspect { + wg.Add(1) + go func(c containerTask) { + defer wg.Done() + info, err := cli.ContainerInspect(ctx, c.containerID) + resultChan <- inspectResult{ct: c, info: info, err: err} + }(ct) + } - serviceHealthyCount[svc.ServiceName] = healthyTaskCount + go func() { + wg.Wait() + close(resultChan) + }() + + // Process inspection results + for result := range resultChan { + ct := result.ct + taskID := ct.task.ID + serviceName := ct.serviceName + + if result.err != nil { + log.Printf("[HealthCheck] Failed to inspect container %s for task %s (%s): %v", + ct.containerID[:12], taskID[:12], serviceName, result.err) + allHealthy = false + unhealthyTasks = append(unhealthyTasks, fmt.Sprintf("%s/%s (inspect failed)", serviceName, taskID[:12])) + continue + } + + // Check health status + if result.info.State.Health != nil { + if result.info.State.Health.Status != container.Healthy { + allHealthy = false + unhealthyTasks = append(unhealthyTasks, fmt.Sprintf("%s/%s (health: %s)", + serviceName, taskID[:12], result.info.State.Health.Status)) + log.Printf("[HealthCheck] ⏳ Task %s (%s) is %s", + taskID[:12], serviceName, result.info.State.Health.Status) + } else { + log.Printf("[HealthCheck] ✅ Task %s (%s) is healthy", taskID[:12], serviceName) + serviceHealthyCount[serviceName]++ + } + } else { + // No healthcheck defined, just check if running + log.Printf("[HealthCheck] ✅ Task %s (%s) is running (no healthcheck)", taskID[:12], serviceName) + serviceHealthyCount[serviceName]++ + } } // Check that all services have at least one healthy task diff --git a/internal/health/monitor.go b/internal/health/monitor.go index 5299333..b65bd5d 100644 --- a/internal/health/monitor.go +++ b/internal/health/monitor.go @@ -41,8 +41,6 @@ type Monitor struct { doneChan chan struct{} // signals monitor has stopped // Lifecycle management - ctx context.Context - cancel context.CancelFunc shutdownOnce sync.Once stopped bool @@ -57,8 +55,6 @@ func NewMonitor(client client.APIClient, taskID string, serviceID string, servic // NewMonitorWithLogs creates a new task monitor with optional log streaming func NewMonitorWithLogs(client client.APIClient, taskID string, serviceID string, serviceName string, showLogs bool) *Monitor { - ctx, cancel := context.WithCancel(context.Background()) - return &Monitor{ client: client, taskID: taskID, @@ -68,8 +64,6 @@ func NewMonitorWithLogs(client client.APIClient, taskID string, serviceID string eventChan: make(chan Event, 10), stopChan: make(chan struct{}), doneChan: make(chan struct{}), - ctx: ctx, - cancel: cancel, healthStatus: "unknown", lastSeen: time.Now(), } @@ -134,7 +128,6 @@ func (m *Monitor) Stop() { m.stopped = true m.mu.Unlock() close(m.stopChan) - m.cancel() }) } @@ -276,6 +269,12 @@ func (m *Monitor) checkHealth(ctx context.Context) { // streamLogs streams container logs for this task func (m *Monitor) streamLogs(ctx context.Context) { + // If Docker client is not initialized, skip log streaming + if m.client == nil { + log.Printf("[TaskLogs] Docker client is nil, skipping log streaming for task %s", m.shortTaskID()) + return + } + log.Printf("[TaskLogs] Waiting for container ID for task %s...", m.shortTaskID()) // Wait for container ID to be available AND task to be running diff --git a/internal/health/monitor_test.go b/internal/health/monitor_test.go index 742713e..96acff7 100644 --- a/internal/health/monitor_test.go +++ b/internal/health/monitor_test.go @@ -196,14 +196,6 @@ func TestMonitor_Stop(t *testing.T) { t.Error("stopChan should be closed") } - // Context should be cancelled - select { - case <-monitor.ctx.Done(): - // Expected - case <-time.After(100 * time.Millisecond): - t.Error("context should be cancelled") - } - // Multiple calls to Stop should be safe monitor.Stop() monitor.Stop() diff --git a/internal/health/watcher.go b/internal/health/watcher.go index 4c36e94..b79b0a6 100644 --- a/internal/health/watcher.go +++ b/internal/health/watcher.go @@ -162,13 +162,15 @@ func (w *Watcher) Subscribe() <-chan Event { } // Unsubscribe removes a subscriber channel +// The channel will be closed automatically when the Watcher shuts down func (w *Watcher) Unsubscribe(ch <-chan Event) { w.subscribersMu.Lock() defer w.subscribersMu.Unlock() for i, sub := range w.subscribers { if sub == ch { - close(sub) + // Don't close here to avoid race condition with broadcastEvents + // The broadcaster will close all channels on shutdown w.subscribers = append(w.subscribers[:i], w.subscribers[i+1:]...) break } @@ -604,30 +606,52 @@ func (w *Watcher) GetTasksForService(serviceID string) []string { } // SubscribeToService creates a filtered channel that only receives events for a specific service -func (w *Watcher) SubscribeToService(serviceID string) <-chan Event { +// Returns the filtered channel and an unsubscribe function that MUST be called to stop the goroutine +func (w *Watcher) SubscribeToService(serviceID string) (<-chan Event, func()) { // Subscribe to all events allEvents := w.Subscribe() - // Create filtered channel + // Create filtered channel and done signal filtered := make(chan Event, 50) + done := make(chan struct{}) // Start filter goroutine go func() { defer close(filtered) - for event := range allEvents { - if event.ServiceID == serviceID { - select { - case filtered <- event: - // Event sent - default: - // Channel full, drop event - log.Printf("[TaskWatcher] WARNING: Filtered channel full for service %s", serviceID) + defer w.Unsubscribe(allEvents) // Clean up parent subscription + + for { + select { + case <-done: + // Unsubscribe called, stop filtering + return + case event, ok := <-allEvents: + if !ok { + // Parent channel closed, stop filtering + return + } + if event.ServiceID == serviceID { + select { + case filtered <- event: + // Event sent + case <-done: + // Unsubscribe called during send + return + default: + // Channel full, drop event + log.Printf("[TaskWatcher] WARNING: Filtered channel full for service %s", serviceID) + } } } } }() - return filtered + // Return unsubscribe function that stops the filter goroutine + unsubscribe := func() { + close(done) + } + + return filtered, unsubscribe } // pollTasks periodically polls Task API to discover new tasks diff --git a/internal/health/watcher_test.go b/internal/health/watcher_test.go index fcbbb5c..9391982 100644 --- a/internal/health/watcher_test.go +++ b/internal/health/watcher_test.go @@ -51,20 +51,38 @@ func TestWatcher_Subscribe_Unsubscribe(t *testing.T) { t.Errorf("Expected 1 subscriber after unsubscribe, got %d", len(watcher.subscribers)) } - // Verify channel is closed - _, ok := <-ch1 - if ok { - t.Error("Expected unsubscribed channel to be closed") + // Note: Unsubscribe does NOT close the channel immediately to avoid race condition + // with broadcastEvents. The channel will be closed on watcher shutdown. + // We just verify that ch1 is removed from subscribers list. + + // Verify ch1 is not in subscribers anymore + watcher.subscribersMu.RLock() + found := false + for _, sub := range watcher.subscribers { + if sub == ch1 { + found = true + break + } } + watcher.subscribersMu.RUnlock() - // ch2 should still be valid - select { - case _, ok := <-ch2: - if !ok { - t.Error("Expected ch2 to still be open") + if found { + t.Error("Expected ch1 to be removed from subscribers") + } + + // ch2 should still be in subscribers + watcher.subscribersMu.RLock() + found = false + for _, sub := range watcher.subscribers { + if sub == ch2 { + found = true + break } - default: - // Channel is open and empty, which is correct + } + watcher.subscribersMu.RUnlock() + + if !found { + t.Error("Expected ch2 to still be in subscribers") } } diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index 29b28c1..16877cc 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -9,16 +9,16 @@ import ( "github.com/SomeBlackMagic/stackman/internal/swarm" ) -// createSnapshot creates a snapshot of current stack state before deployment -func CreateSnapshot(ctx context.Context, stackDeployer *swarm.StackDeployer) *swarm.StackSnapshot { +// CreateSnapshot creates a snapshot of current stack state before deployment +// Returns error if snapshot creation fails to prevent deployment without rollback capability +func CreateSnapshot(ctx context.Context, stackDeployer *swarm.StackDeployer) (*swarm.StackSnapshot, error) { log.Println("Creating snapshot of current stack state...") snapshot, err := stackDeployer.CreateSnapshot(ctx) if err != nil { - log.Printf("Warning: failed to create snapshot: %v", err) - log.Println("Continuing without rollback capability") - return nil + return nil, fmt.Errorf("failed to create snapshot (rollback will not be available): %w", err) } - return snapshot + log.Println("Snapshot created successfully") + return snapshot, nil } // rollback restores the stack to a previous snapshot state