diff --git a/ios/tunnel/tunnel_api.go b/ios/tunnel/tunnel_api.go index db98f026..51f8f52c 100644 --- a/ios/tunnel/tunnel_api.go +++ b/ios/tunnel/tunnel_api.go @@ -270,7 +270,7 @@ func (m *TunnelManager) FirstUpdateCompleted() bool { // UpdateTunnels checks for connected devices and starts a new tunnel if needed // On device disconnects the tunnel resources get cleaned up -func (m *TunnelManager) UpdateTunnels(ctx context.Context) error { +func (m *TunnelManager) UpdateTunnels(ctx context.Context, excludedDevices map[string]struct{}) error { m.mux.Lock() localTunnels := map[string]Tunnel{} @@ -283,6 +283,11 @@ func (m *TunnelManager) UpdateTunnels(ctx context.Context) error { } for _, d := range devices.DeviceList { udid := d.Properties.SerialNumber + if excludedDevices != nil { + if _, excluded := excludedDevices[udid]; excluded { + continue + } + } if _, exists := localTunnels[udid]; exists { continue } diff --git a/ios/tunnel/tunnel_api_test.go b/ios/tunnel/tunnel_api_test.go index 2140341b..0c109096 100644 --- a/ios/tunnel/tunnel_api_test.go +++ b/ios/tunnel/tunnel_api_test.go @@ -40,7 +40,7 @@ func TestSuccessStartForMultipleConnectedDevices(t *testing.T) { Udid: "serial2", }, nil) - err := tm.UpdateTunnels(context.Background()) + err := tm.UpdateTunnels(context.Background(), nil) assert.NoError(t, err) tunnels, err := tm.ListTunnels() @@ -80,7 +80,7 @@ func TestCloseTunnelsOnDisconnect(t *testing.T) { closer: closer, }, nil) - err := tm.UpdateTunnels(context.Background()) + err := tm.UpdateTunnels(context.Background(), nil) assert.NoError(t, err) tunnels, _ := tm.ListTunnels() @@ -90,7 +90,7 @@ func TestCloseTunnelsOnDisconnect(t *testing.T) { Return(ios.DeviceList{}, nil). Once() - err = tm.UpdateTunnels(context.Background()) + err = tm.UpdateTunnels(context.Background(), nil) assert.NoError(t, err) tunnels, _ = tm.ListTunnels() assert.Len(t, tunnels, 0) @@ -115,9 +115,9 @@ func TestBridgeIsOnlyStarteOnce(t *testing.T) { closer: closer, }, nil) - err := tm.UpdateTunnels(context.Background()) + err := tm.UpdateTunnels(context.Background(), nil) assert.NoError(t, err) - err = tm.UpdateTunnels(context.Background()) + err = tm.UpdateTunnels(context.Background(), nil) assert.NoError(t, err) ts.AssertNumberOfCalls(t, "StartTunnel", 1) diff --git a/main.go b/main.go index 595f478a..cb78dc17 100644 --- a/main.go +++ b/main.go @@ -2334,7 +2334,7 @@ func startTunnel(ctx context.Context, recordsPath string, tunnelInfoPort int, us case <-ctx.Done(): return case <-ticker.C: - err := tm.UpdateTunnels(ctx) + err := tm.UpdateTunnels(ctx, nil) if err != nil { log.WithError(err).Warn("failed to update tunnels") }