feat: enable multiple containers for pipeline/model_id (#148)
This commit makes the container map keys unique, allowing users to run multiple containers for the same pipeline/model_id and to serve multiple pipelines behind one external endpoint. (A sketch of the new naming scheme follows below.)

Co-authored-by: Rick Staa <[email protected]>
ad-astra-video and rickstaa authored Aug 29, 2024
1 parent 8bb410c commit a02b3a9
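
To make the naming change concrete, here is a minimal, self-contained sketch (not part of the commit) that mirrors the new dockerContainerName helper from worker/docker.go below; the pipeline, model ID, and port-suffix values are made-up examples:

package main

import (
	"fmt"
	"strings"
)

// Mirrors dockerContainerName from worker/docker.go as changed in this commit.
func dockerContainerName(pipeline string, modelID string, suffix ...string) string {
	sanitizedModelID := strings.NewReplacer("/", "-", "_", "-").Replace(modelID)
	if len(suffix) > 0 {
		return fmt.Sprintf("%s_%s_%s", pipeline, sanitizedModelID, suffix[0])
	}
	return fmt.Sprintf("%s_%s", pipeline, sanitizedModelID)
}

func main() {
	// With a per-container suffix (here a host port), two containers for the
	// same pipeline/model_id no longer collide on the same map key.
	fmt.Println(dockerContainerName("text-to-image", "stabilityai/sd-turbo", "8000"))
	fmt.Println(dockerContainerName("text-to-image", "stabilityai/sd-turbo", "8001"))
	// Output:
	// text-to-image_stabilityai-sd-turbo_8000
	// text-to-image_stabilityai-sd-turbo_8001
}

With the optional suffix, two runners for the same pipeline/model_id map to distinct keys, which is what lets the worker hold several of them at once.
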
Showing 5 changed files with 70 additions and 63 deletions.
10 changes: 5 additions & 5 deletions go.mod
@@ -34,11 +34,11 @@ require (
 	github.com/perimeterx/marshmallow v1.1.5 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/sirupsen/logrus v1.9.3 // indirect
-	golang.org/x/mod v0.17.0 // indirect
-	golang.org/x/net v0.25.0 // indirect
-	golang.org/x/sync v0.7.0 // indirect
-	golang.org/x/sys v0.20.0 // indirect
-	golang.org/x/tools v0.21.0 // indirect
+	golang.org/x/mod v0.20.0 // indirect
+	golang.org/x/net v0.28.0 // indirect
+	golang.org/x/sync v0.8.0 // indirect
+	golang.org/x/sys v0.23.0 // indirect
+	golang.org/x/tools v0.24.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 	gotest.tools/v3 v3.5.1 // indirect
 )
20 changes: 10 additions & 10 deletions go.sum
@@ -91,25 +91,25 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
-golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
+golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
-golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
+golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
+golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
-golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
+golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
@@ -118,8 +118,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
-golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
+golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
+golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
5 changes: 3 additions & 2 deletions worker/container.go
@@ -17,7 +17,7 @@ const (
 
 type RunnerContainer struct {
 	RunnerContainerConfig
-
+	Name   string
 	Client *ClientWithResponses
 }
 
@@ -39,7 +39,7 @@ type RunnerContainerConfig struct {
 	containerTimeout time.Duration
 }
 
-func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*RunnerContainer, error) {
+func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig, name string) (*RunnerContainer, error) {
 	// Ensure that timeout is set to a non-zero value.
 	timeout := cfg.containerTimeout
 	if timeout == 0 {
@@ -70,6 +70,7 @@ func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*Runner
 
 	return &RunnerContainer{
 		RunnerContainerConfig: cfg,
+		Name:                  name,
 		Client:                client,
 	}, nil
 }
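
As a reading aid for the change above, here is a simplified, hypothetical sketch of why the name now travels with the runner; the lower-case types are stand-ins for the real RunnerContainer and RunnerContainerConfig, whose remaining fields are elided:

package main

import "fmt"

// Stand-ins for the real types in worker/container.go; other fields elided.
type runnerContainerConfig struct {
	Pipeline string
	ModelID  string
}

type runnerContainer struct {
	runnerContainerConfig
	Name string // new in this commit: the unique container name travels with the runner
}

// Mirrors the new NewRunnerContainer shape: the caller supplies the name.
func newRunnerContainer(cfg runnerContainerConfig, name string) *runnerContainer {
	return &runnerContainer{runnerContainerConfig: cfg, Name: name}
}

func main() {
	cfg := runnerContainerConfig{Pipeline: "text-to-image", ModelID: "stabilityai/sd-turbo"}
	rc := newRunnerContainer(cfg, "text-to-image_stabilityai-sd-turbo_8000")
	fmt.Println(rc.Name) // the key DockerManager.Return uses to put the runner back
}

Storing the name at construction time means later code (for example DockerManager.Return in worker/docker.go) can re-key the container map without recomputing a name from pipeline/model_id.
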
63 changes: 35 additions & 28 deletions worker/docker.go
@@ -3,6 +3,7 @@ package worker
 import (
 	"context"
 	"errors"
+	"fmt"
 	"log/slog"
 	"strings"
 	"sync"
@@ -31,10 +32,10 @@ const containerCreator = "ai-worker"
 // using the GPU we stop it so we don't have to worry about having enough ports
 var containerHostPorts = map[string]string{
 	"text-to-image":  "8000",
-	"image-to-image": "8001",
-	"image-to-video": "8002",
-	"upscale":        "8003",
-	"audio-to-text":  "8004",
+	"image-to-image": "8100",
+	"image-to-video": "8200",
+	"upscale":        "8300",
+	"audio-to-text":  "8400",
 }
 
 type DockerManager struct {
@@ -109,40 +110,42 @@ func (m *DockerManager) Borrow(ctx context.Context, pipeline, modelID string) (*
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
-	containerName := dockerContainerName(pipeline, modelID)
-	rc, ok := m.containers[containerName]
-	if !ok {
-		// The container does not exist so try to create it
-		var err error
-		// TODO: Optimization flags for dynamically loaded (borrowed) containers are not currently supported due to startup delays.
-		rc, err = m.createContainer(ctx, pipeline, modelID, false, map[string]EnvValue{})
-		if err != nil {
-			return nil, err
+	for _, runner := range m.containers {
+		if runner.Pipeline == pipeline && runner.ModelID == modelID {
+			delete(m.containers, runner.Name)
+			return runner, nil
 		}
 	}
 
+	// The container does not exist so try to create it
+	var err error
+	// TODO: Optimization flags for dynamically loaded (borrowed) containers are not currently supported due to startup delays.
+	rc, err := m.createContainer(ctx, pipeline, modelID, false, map[string]EnvValue{})
+	if err != nil {
+		return nil, err
+	}
+
 	// Remove container so it is unavailable until Return() is called
-	delete(m.containers, containerName)
+	delete(m.containers, rc.Name)
 	return rc, nil
 }
 
 func (m *DockerManager) Return(rc *RunnerContainer) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	m.containers[dockerContainerName(rc.Pipeline, rc.ModelID)] = rc
+	m.containers[rc.Name] = rc
 }
 
 // HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.
 func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID string) bool {
-	containerName := dockerContainerName(pipeline, modelID)
-
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
 	// Check if unused managed container exists for the requested model.
-	_, ok := m.containers[containerName]
-	if ok {
-		return true
+	for _, rc := range m.containers {
+		if rc.Pipeline == pipeline && rc.ModelID == modelID {
+			return true
+		}
 	}
 
 	// Check for available GPU to allocate for a new container for the requested model.
@@ -151,13 +154,15 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
 }
 
 func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) {
-	containerName := dockerContainerName(pipeline, modelID)
-
 	gpu, err := m.allocGPU(ctx)
 	if err != nil {
 		return nil, err
 	}
 
+	// NOTE: We currently allow only one container per GPU for each pipeline.
+	containerHostPort := containerHostPorts[pipeline][:3] + gpu
+	containerName := dockerContainerName(pipeline, modelID, containerHostPort)
+
 	slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID))
 
 	// Add optimization flags as environment variables.
@@ -186,7 +191,6 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
 	gpuOpts := opts.GpuOpts{}
 	gpuOpts.Set("device=" + gpu)
 
-	containerHostPort := containerHostPorts[pipeline]
 	hostConfig := &container.HostConfig{
 		Resources: container.Resources{
 			DeviceRequests: gpuOpts.Value(),
@@ -249,7 +253,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
 		containerTimeout: runnerContainerTimeout,
 	}
 
-	rc, err := NewRunnerContainer(ctx, cfg)
+	rc, err := NewRunnerContainer(ctx, cfg, containerName)
 	if err != nil {
 		dockerRemoveContainer(m.dockerClient, resp.ID)
 		return nil, err
@@ -309,10 +313,13 @@ func removeExistingContainers(ctx context.Context, client *client.Client) error
 	return nil
 }
 
-func dockerContainerName(pipeline string, modelID string) string {
-	// text-to-image, stabilityai/sd-turbo -> text-to-image_stabilityai_sd-turbo
-	// image-to-video, stabilityai/stable-video-diffusion-img2vid-xt -> image-to-video_stabilityai_stable-video-diffusion-img2vid-xt
-	return strings.ReplaceAll(pipeline+"_"+modelID, "/", "_")
+// dockerContainerName generates a unique container name based on the pipeline, model ID, and an optional suffix.
+func dockerContainerName(pipeline string, modelID string, suffix ...string) string {
+	sanitizedModelID := strings.NewReplacer("/", "-", "_", "-").Replace(modelID)
+	if len(suffix) > 0 {
+		return fmt.Sprintf("%s_%s_%s", pipeline, sanitizedModelID, suffix[0])
+	}
+	return fmt.Sprintf("%s_%s", pipeline, sanitizedModelID)
 }
 
 func dockerRemoveContainer(client *client.Client, containerID string) error {
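
The per-GPU host-port derivation in createContainer above is terse; this minimal sketch (assumed values, not part of the commit) spells out the arithmetic:

package main

import "fmt"

func main() {
	// Each pipeline owns a base host port; createContainer keeps the first three
	// digits and appends the allocated GPU index, giving one port per
	// pipeline/GPU pair. Note this assumes single-digit GPU indices.
	containerHostPorts := map[string]string{
		"text-to-image":  "8000",
		"image-to-image": "8100",
	}
	gpu := "2" // stand-in for the index string returned by allocGPU
	fmt.Println(containerHostPorts["text-to-image"][:3] + gpu)  // 8002
	fmt.Println(containerHostPorts["image-to-image"][:3] + gpu) // 8102
}

Spacing the base ports 100 apart (8000, 8100, 8200, ...) is what keeps the derived per-GPU ports of different pipelines from overlapping.
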
35 changes: 17 additions & 18 deletions worker/worker.go
@@ -319,12 +319,12 @@ func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endp
 		Endpoint:         endpoint,
 		containerTimeout: externalContainerTimeout,
 	}
-	rc, err := NewRunnerContainer(ctx, cfg)
+	rc, err := NewRunnerContainer(ctx, cfg, endpoint.URL)
 	if err != nil {
 		return err
 	}
 
-	name := dockerContainerName(pipeline, modelID)
+	name := dockerContainerName(pipeline, modelID, endpoint.URL)
 	slog.Info("Starting external container", slog.String("name", name), slog.String("modelID", modelID))
 	w.externalContainers[name] = rc
 
@@ -348,30 +348,29 @@ func (w *Worker) Stop(ctx context.Context) error {
 
 // HasCapacity returns true if the worker has capacity for the given pipeline and model ID.
 func (w *Worker) HasCapacity(pipeline, modelID string) bool {
-	managedCapacity := w.manager.HasCapacity(context.Background(), pipeline, modelID)
-	if managedCapacity {
-		return true
-	}
-
-	// Check if we have capacity for external containers.
-	name := dockerContainerName(pipeline, modelID)
 	w.mu.Lock()
 	defer w.mu.Unlock()
-	_, ok := w.externalContainers[name]
-
-	return ok
+
+	// Check if we have capacity for external containers.
+	for _, rc := range w.externalContainers {
+		if rc.Pipeline == pipeline && rc.ModelID == modelID {
+			return true
+		}
+	}
+
+	// Check if we have capacity for managed containers.
+	return w.manager.HasCapacity(context.Background(), pipeline, modelID)
 }
 
 func (w *Worker) borrowContainer(ctx context.Context, pipeline, modelID string) (*RunnerContainer, error) {
 	w.mu.Lock()
 
-	name := dockerContainerName(pipeline, modelID)
-	rc, ok := w.externalContainers[name]
-	if ok {
-		w.mu.Unlock()
-		// We allow concurrent in-flight requests for external containers and assume that it knows
-		// how to handle them
-		return rc, nil
+	for _, rc := range w.externalContainers {
+		if rc.Pipeline == pipeline && rc.ModelID == modelID {
+			w.mu.Unlock()
+			// Assume external containers can handle concurrent in-flight requests.
+			return rc, nil
+		}
 	}
 
 	w.mu.Unlock()
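
Finally, a simplified sketch of the lookup pattern that replaces exact-key map access throughout this commit; the runner type is a stand-in for RunnerContainer and the map contents are invented:

package main

import "fmt"

// Simplified stand-in for RunnerContainer (see worker/container.go).
type runner struct {
	Name, Pipeline, ModelID string
}

// With unique, suffix-bearing names, callers can no longer rebuild the map key
// from pipeline/model_id alone, so lookups scan the values instead.
func hasCapacity(containers map[string]*runner, pipeline, modelID string) bool {
	for _, rc := range containers {
		if rc.Pipeline == pipeline && rc.ModelID == modelID {
			return true
		}
	}
	return false
}

func main() {
	containers := map[string]*runner{
		"text-to-image_stabilityai-sd-turbo_8000": {Name: "text-to-image_stabilityai-sd-turbo_8000", Pipeline: "text-to-image", ModelID: "stabilityai/sd-turbo"},
		"text-to-image_stabilityai-sd-turbo_8001": {Name: "text-to-image_stabilityai-sd-turbo_8001", Pipeline: "text-to-image", ModelID: "stabilityai/sd-turbo"},
	}
	fmt.Println(hasCapacity(containers, "text-to-image", "stabilityai/sd-turbo")) // true
	fmt.Println(hasCapacity(containers, "upscale", "stabilityai/sd-turbo"))       // false
}

The linear scan trades O(1) key lookup for O(n) iteration, which is negligible here since a worker manages only a handful of containers.
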
