Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 162 additions & 5 deletions server/cmd/api/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,20 @@ import (
oapi "github.com/onkernel/kernel-images/server/lib/oapi"
"github.com/onkernel/kernel-images/server/lib/recorder"
"github.com/onkernel/kernel-images/server/lib/scaletozero"
"github.com/onkernel/kernel-images/server/lib/stream"
)

type ApiService struct {
// defaultRecorderID is used whenever the caller doesn't specify an explicit ID.
defaultRecorderID string

recordManager recorder.RecordManager
factory recorder.FFmpegRecorderFactory
defaultStreamID string

recordManager recorder.RecordManager
factory recorder.FFmpegRecorderFactory
streamManager stream.Manager
streamFactory stream.FFmpegStreamerFactory
rtmpServer stream.InternalServer
streamDefaults stream.Params
// Filesystem watch management
watchMu sync.RWMutex
watches map[string]*fsWatch
Expand All @@ -46,7 +52,7 @@ type ApiService struct {

var _ oapi.StrictServerInterface = (*ApiService)(nil)

func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFactory, upstreamMgr *devtoolsproxy.UpstreamManager, stz scaletozero.Controller, nekoAuthClient *nekoclient.AuthClient) (*ApiService, error) {
func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFactory, upstreamMgr *devtoolsproxy.UpstreamManager, stz scaletozero.Controller, nekoAuthClient *nekoclient.AuthClient, streamManager stream.Manager, streamFactory stream.FFmpegStreamerFactory, rtmpServer stream.InternalServer, streamDefaults stream.Params) (*ApiService, error) {
switch {
case recordManager == nil:
return nil, fmt.Errorf("recordManager cannot be nil")
Expand All @@ -56,12 +62,26 @@ func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFa
return nil, fmt.Errorf("upstreamMgr cannot be nil")
case nekoAuthClient == nil:
return nil, fmt.Errorf("nekoAuthClient cannot be nil")
case streamManager == nil:
return nil, fmt.Errorf("streamManager cannot be nil")
case streamFactory == nil:
return nil, fmt.Errorf("streamFactory cannot be nil")
case rtmpServer == nil:
return nil, fmt.Errorf("rtmpServer cannot be nil")
}
if streamDefaults.FrameRate == nil || streamDefaults.DisplayNum == nil {
return nil, fmt.Errorf("streamDefaults must include frame rate and display number")
}

return &ApiService{
recordManager: recordManager,
factory: factory,
defaultRecorderID: "default",
streamManager: streamManager,
streamFactory: streamFactory,
rtmpServer: rtmpServer,
streamDefaults: streamDefaults,
defaultStreamID: "default",
watches: make(map[string]*fsWatch),
procs: make(map[string]*processHandle),
upstreamMgr: upstreamMgr,
Expand Down Expand Up @@ -236,6 +256,129 @@ func (s *ApiService) DeleteRecording(ctx context.Context, req oapi.DeleteRecordi
return oapi.DeleteRecording200Response{}, nil
}

func (s *ApiService) StartStream(ctx context.Context, req oapi.StartStreamRequestObject) (oapi.StartStreamResponseObject, error) {
log := logger.FromContext(ctx)

if req.Body == nil {
return oapi.StartStream400JSONResponse{BadRequestErrorJSONResponse: oapi.BadRequestErrorJSONResponse{Message: "request body required"}}, nil
}

streamID := s.defaultStreamID
if req.Body.Id != nil && *req.Body.Id != "" {
streamID = *req.Body.Id
}

mode := stream.ModeInternal
if req.Body.Mode != nil && *req.Body.Mode != "" {
mode = stream.Mode(*req.Body.Mode)
}
if mode != stream.ModeInternal && mode != stream.ModeRemote {
return oapi.StartStream400JSONResponse{BadRequestErrorJSONResponse: oapi.BadRequestErrorJSONResponse{Message: "invalid stream mode"}}, nil
}

frameRate := s.streamDefaults.FrameRate
if req.Body.Framerate != nil {
frameRate = req.Body.Framerate
}

if existing, ok := s.streamManager.GetStream(streamID); ok {
if existing.IsStreaming(ctx) {
return oapi.StartStream409JSONResponse{ConflictErrorJSONResponse: oapi.ConflictErrorJSONResponse{Message: "stream already in progress"}}, nil
}
_ = s.streamManager.DeregisterStream(ctx, existing)
}

var ingestURL string
var playbackURL *string
var securePlaybackURL *string
streamPath := fmt.Sprintf("live/%s", streamID)

switch mode {
case stream.ModeInternal:
if err := s.rtmpServer.Start(ctx); err != nil {
log.Error("failed to start internal rtmp server", "err", err)
return oapi.StartStream500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to start internal streaming server"}}, nil
}
s.rtmpServer.EnsureStream(streamPath)
ingestURL = s.rtmpServer.IngestURL(streamPath)
playbackURL, securePlaybackURL = s.rtmpServer.PlaybackURLs("", streamPath)
case stream.ModeRemote:
if req.Body.TargetUrl == nil || *req.Body.TargetUrl == "" {
return oapi.StartStream400JSONResponse{BadRequestErrorJSONResponse: oapi.BadRequestErrorJSONResponse{Message: "target_url is required for remote streaming"}}, nil
}
ingestURL = *req.Body.TargetUrl
playbackURL = &ingestURL
}

params := stream.Params{
FrameRate: frameRate,
DisplayNum: s.streamDefaults.DisplayNum,
Mode: mode,
IngestURL: ingestURL,
PlaybackURL: playbackURL,
SecurePlaybackURL: securePlaybackURL,
}

streamer, err := s.streamFactory(streamID, params)
if err != nil {
log.Error("failed to create streamer", "err", err, "stream_id", streamID)
return oapi.StartStream500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to create streamer"}}, nil
}
if err := s.streamManager.RegisterStream(ctx, streamer); err != nil {
log.Error("failed to register stream", "err", err, "stream_id", streamID)
return oapi.StartStream409JSONResponse{ConflictErrorJSONResponse: oapi.ConflictErrorJSONResponse{Message: "stream already exists"}}, nil
}
if err := streamer.Start(ctx); err != nil {
log.Error("failed to start stream", "err", err, "stream_id", streamID)
_ = s.streamManager.DeregisterStream(ctx, streamer)
return oapi.StartStream500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to start stream"}}, nil
}

return oapi.StartStream201JSONResponse(streamMetadataToOAPI(streamer.Metadata(), streamer.IsStreaming(ctx))), nil
}

func (s *ApiService) StopStream(ctx context.Context, req oapi.StopStreamRequestObject) (oapi.StopStreamResponseObject, error) {
log := logger.FromContext(ctx)

streamID := s.defaultStreamID
if req.Body != nil && req.Body.Id != nil && *req.Body.Id != "" {
streamID = *req.Body.Id
}

streamer, ok := s.streamManager.GetStream(streamID)
if !ok {
return oapi.StopStream404JSONResponse{NotFoundErrorJSONResponse: oapi.NotFoundErrorJSONResponse{Message: "stream not found"}}, nil
}

if err := streamer.Stop(ctx); err != nil {
log.Error("failed to stop stream", "err", err, "stream_id", streamID)
}
_ = s.streamManager.DeregisterStream(ctx, streamer)

return oapi.StopStream200Response{}, nil
}

func (s *ApiService) ListStreams(ctx context.Context, _ oapi.ListStreamsRequestObject) (oapi.ListStreamsResponseObject, error) {
streams := s.streamManager.ListStreams(ctx)
infos := make([]oapi.StreamInfo, 0, len(streams))
for _, st := range streams {
infos = append(infos, streamMetadataToOAPI(st.Metadata(), st.IsStreaming(ctx)))
}
return oapi.ListStreams200JSONResponse(infos), nil
}

func streamMetadataToOAPI(meta stream.Metadata, isStreaming bool) oapi.StreamInfo {
return oapi.StreamInfo{
Id: meta.ID,
Mode: oapi.StreamInfoMode(meta.Mode),
IngestUrl: meta.IngestURL,
PlaybackUrl: meta.PlaybackURL,
SecurePlaybackUrl: meta.SecurePlaybackURL,
StartedAt: meta.StartedAt,
IsStreaming: isStreaming,
}
}

// ListRecorders returns a list of all registered recorders and whether each one is currently recording.
func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequestObject) (oapi.ListRecordersResponseObject, error) {
infos := []oapi.RecorderInfo{}
Expand All @@ -261,5 +404,19 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ
}

func (s *ApiService) Shutdown(ctx context.Context) error {
return s.recordManager.StopAll(ctx)
var errs []error
if err := s.recordManager.StopAll(ctx); err != nil {
errs = append(errs, err)
}
if err := s.streamManager.StopAll(ctx); err != nil {
errs = append(errs, err)
}
if err := s.rtmpServer.Close(ctx); err != nil {
errs = append(errs, err)
}

if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
Loading
Loading