diff --git a/detector_test.go b/detector_test.go index bef8825..43909c8 100644 --- a/detector_test.go +++ b/detector_test.go @@ -604,6 +604,71 @@ func TestSignalFork(t *testing.T) { assert.Equal(t, ProcessExitEvent.String(), exitEvent.EventType.String(), "third event should be exit") } +func TestTrackProcessesBeforeRun(t *testing.T) { + require.NotEmpty(t, sleepLocation, "sleep must be installed for the test") + + // start a process before starting the detector + cmd := exec.Command(sleepLocation, "30") + require.NoError(t, cmd.Start()) + proc := &testProcess{cmd: cmd, pid: cmd.Process.Pid} + defer proc.stop() + + events := make(chan ProcessEvent, 100) + d, err := NewDetector(events, + // set no duration filter + WithMinDuration(0), + // our test process doesn't have this env var, + // but we want to make sure TrackProcesses causes the detector to emit an exit event for the tracked pid + // even if the process doesn't match the environment variable filters + WithEnvPrefixFilter("USER_E"), + ) + require.NoError(t, err) + + // the detector should handle TrackProcesses being called before Run + require.NotPanics(t, func() { + err := d.TrackProcesses([]int{proc.pid}) + require.NoError(t, err) + }, "TrackProcesses should not panic when called before Run") + + // run the detector and kill the target process before it exits + // we expect to receive an exit event + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + runDone := make(chan struct{}) + go func() { + defer close(runDone) + err := d.Run(ctx) + require.NoError(t, err) + }() + + time.Sleep(500 * time.Millisecond) + proc.stop() + + var gotExit bool + deadline := time.After(2 * time.Second) +collect: + for { + select { + case e, ok := <-events: + if !ok { + break collect + } + if e.PID == proc.pid && e.EventType == ProcessExitEvent { + gotExit = true + break collect + } + case <-deadline: + break collect + } + } + + cancel() + <-runDone + + require.True(t, gotExit, "expected exit event for tracked pid %d registered before Run", proc.pid) +} + func envVarsToSlice(envVars map[string]string) []string { var result []string for k, v := range envVars { diff --git a/internal/probe/probe.go b/internal/probe/probe.go index a0bdd5f..99aa068 100644 --- a/internal/probe/probe.go +++ b/internal/probe/probe.go @@ -9,6 +9,7 @@ import ( "math" "os" "runtime" + "sync" "github.com/cilium/ebpf" "github.com/cilium/ebpf/features" @@ -39,6 +40,12 @@ type Probe struct { openFilesToTrack []string execFilesToFilter map[string]struct{} btfDisabled bool + + // pendingPIDsToTrack holds the PIDs that should be tracked by the probe, + // since TrackPIDs can be called before the probe is fully loaded and the eBPF collection is not available yet. + // Once the probe is loaded, these PIDs will be written to the eBPF map and tracked properly. + pendingPIDsToTrack []int + mu sync.Mutex } type processEvent struct { @@ -119,6 +126,20 @@ func (p *Probe) LoadAndAttach() error { return fmt.Errorf("can't attach probe: %w", err) } + // TrackPIDs might be called before the probe is fully loaded, + // in that case we need to write the pending PIDs to the map now that the probe is loaded + p.mu.Lock() + pending := p.pendingPIDsToTrack + p.pendingPIDsToTrack = nil + p.mu.Unlock() + if len(pending) == 0 { + return nil + } + err = p.writePIDsToMap(pending) + if err != nil { + return fmt.Errorf("can't write pending PIDs to map: %w", err) + } + return nil } @@ -410,6 +431,25 @@ func (p *Probe) GetContainerPID(pid int) (int, error) { } func (p *Probe) TrackPIDs(pids []int) error { + p.mu.Lock() + if p.c == nil { + // probe is not fully loaded yet, + // store the PIDs in the pending list to be written to the map once the probe is loaded + p.pendingPIDsToTrack = append(p.pendingPIDsToTrack, pids...) + p.mu.Unlock() + return nil + } + p.mu.Unlock() + + // probe is loaded, write the PIDs to the map to be tracked by the probe + return p.writePIDsToMap(pids) +} + +func (p *Probe) writePIDsToMap(pids []int) error { + if p.c == nil { + return errors.New("eBPF collection is not loaded") + } + m, ok := p.c.Maps[pidToContainerPIDMapName] if !ok { return errors.New("eBPF maps are not loaded")