Skip to content

Commit 722ec52

Browse files
authored
feat: add support for DDUF/diffusers backend selection in model scheduler (#809)
1 parent 85a03ea commit 722ec52

3 files changed

Lines changed: 83 additions & 7 deletions

File tree

pkg/inference/scheduling/http_handler.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque
206206
// Non-blocking call to track the model usage.
207207
h.scheduler.tracker.TrackModel(model, r.UserAgent(), action)
208208

209-
// Automatically identify models for vLLM.
209+
// Automatically select backend for given model.
210210
backend = h.scheduler.selectBackendForModel(model, backend, request.Model)
211211
}
212212

pkg/inference/scheduling/scheduler.go

Lines changed: 33 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/docker/model-runner/pkg/distribution/types"
1212
"github.com/docker/model-runner/pkg/inference"
13+
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
1314
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
1415
"github.com/docker/model-runner/pkg/inference/backends/mlx"
1516
"github.com/docker/model-runner/pkg/inference/backends/sglang"
@@ -30,6 +31,7 @@ type PlatformSupport interface {
3031
SupportsVLLM() bool
3132
SupportsVLLMMetal() bool
3233
SupportsSGLang() bool
34+
SupportsDiffusers() bool
3335
}
3436

3537
// defaultPlatformSupport delegates to the platform package.
@@ -39,6 +41,7 @@ func (defaultPlatformSupport) SupportsMLX() bool { return platform.Support
3941
func (defaultPlatformSupport) SupportsVLLM() bool { return platform.SupportsVLLM() }
4042
func (defaultPlatformSupport) SupportsVLLMMetal() bool { return platform.SupportsVLLMMetal() }
4143
func (defaultPlatformSupport) SupportsSGLang() bool { return platform.SupportsSGLang() }
44+
func (defaultPlatformSupport) SupportsDiffusers() bool { return platform.SupportsDiffusers() }
4245

4346
// Scheduler is used to coordinate inference scheduling across multiple backends
4447
// and models.
@@ -121,18 +124,18 @@ func (s *Scheduler) Run(ctx context.Context) error {
121124
}
122125

123126
// selectBackendForModel selects the appropriate backend for a model based on its format.
124-
// If the model is in safetensors format, it will prefer the best available backend:
125-
// - vLLM (handles platform dispatch internally: vllm-metal on macOS ARM64, standard vLLM on Linux)
126-
// - MLX on macOS
127-
// - SGLang on Linux
127+
// For safetensors models, it prefers: vLLM > MLX > SGLang.
128+
// For DDUF/diffusers models, it selects the diffusers backend.
129+
// For other formats (e.g. GGUF), it returns the provided default backend.
128130
func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.Backend, modelRef string) inference.Backend {
129131
config, err := model.Config()
130132
if err != nil {
131133
s.log.Warn("failed to fetch model config", "error", err)
132134
return backend
133135
}
134136

135-
if config.GetFormat() == types.FormatSafetensors {
137+
switch config.GetFormat() {
138+
case types.FormatSafetensors:
136139
// Prefer vLLM for safetensors models (handles platform dispatch internally)
137140
if s.platformSupport.SupportsVLLM() || s.platformSupport.SupportsVLLMMetal() {
138141
if vllmBackend, ok := s.backends[vllm.Name]; ok && vllmBackend != nil {
@@ -151,8 +154,32 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B
151154
return sglangBackend
152155
}
153156
}
157+
backendName := "none"
158+
if backend != nil {
159+
backendName = backend.Name()
160+
}
154161
s.log.Warn("Model is in safetensors format but no compatible backend is available",
155-
"model", utils.SanitizeForLog(modelRef), "backend", backend.Name())
162+
"model", utils.SanitizeForLog(modelRef), "backend", backendName)
163+
164+
case types.FormatDDUF, types.FormatDiffusers: //nolint:staticcheck // FormatDiffusers kept for backward compatibility
165+
// Select the diffusers backend for DDUF and legacy diffusers format models
166+
if s.platformSupport.SupportsDiffusers() {
167+
if diffusersBackend, ok := s.backends[diffusers.Name]; ok && diffusersBackend != nil {
168+
return diffusersBackend
169+
}
170+
}
171+
backendName := "none"
172+
if backend != nil {
173+
backendName = backend.Name()
174+
}
175+
s.log.Warn("Model is in DDUF/diffusers format but no compatible backend is available",
176+
"model", utils.SanitizeForLog(modelRef), "backend", backendName)
177+
178+
case types.FormatGGUF:
179+
// GGUF models use the default backend (llamacpp)
180+
181+
default:
182+
// Unknown formats use the default backend
156183
}
157184

158185
return backend

pkg/inference/scheduling/select_backend_test.go

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/docker/model-runner/pkg/distribution/types"
88
"github.com/docker/model-runner/pkg/inference"
9+
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
910
"github.com/docker/model-runner/pkg/inference/backends/mlx"
1011
"github.com/docker/model-runner/pkg/inference/backends/sglang"
1112
"github.com/docker/model-runner/pkg/inference/backends/vllm"
@@ -17,12 +18,14 @@ type mockPlatformSupport struct {
1718
vllm bool
1819
vllmMetal bool
1920
sglang bool
21+
diffusers bool
2022
}
2123

2224
func (m mockPlatformSupport) SupportsMLX() bool { return m.mlx }
2325
func (m mockPlatformSupport) SupportsVLLM() bool { return m.vllm }
2426
func (m mockPlatformSupport) SupportsVLLMMetal() bool { return m.vllmMetal }
2527
func (m mockPlatformSupport) SupportsSGLang() bool { return m.sglang }
28+
func (m mockPlatformSupport) SupportsDiffusers() bool { return m.diffusers }
2629

2730
// mockModel is a minimal Model implementation for testing.
2831
type mockModel struct {
@@ -55,9 +58,12 @@ func TestSelectBackendForModel(t *testing.T) {
5558
mlxBackend := &mockBackend{name: mlx.Name}
5659
vllmBackend := &mockBackend{name: vllm.Name}
5760
sglangBackend := &mockBackend{name: sglang.Name}
61+
diffusersBackend := &mockBackend{name: diffusers.Name}
5862

5963
safetensorsModel := &mockModel{config: &types.Config{Format: types.FormatSafetensors}}
6064
ggufModel := &mockModel{config: &types.Config{Format: types.FormatGGUF}}
65+
ddufModel := &mockModel{config: &types.Config{Format: types.FormatDDUF}}
66+
legacyDiffusersModel := &mockModel{config: &types.Config{Format: types.FormatDiffusers}} //nolint:staticcheck // testing backward compatibility
6167

6268
tests := []struct {
6369
name string
@@ -153,6 +159,49 @@ func TestSelectBackendForModel(t *testing.T) {
153159
model: safetensorsModel,
154160
expectedBackend: vllm.Name,
155161
},
162+
{
163+
name: "DDUF model selects diffusers backend when platform supports it",
164+
backends: map[string]inference.Backend{
165+
"llamacpp": llamacppBackend,
166+
diffusers.Name: diffusersBackend,
167+
},
168+
defaultBackend: llamacppBackend,
169+
platform: mockPlatformSupport{diffusers: true},
170+
model: ddufModel,
171+
expectedBackend: diffusers.Name,
172+
},
173+
{
174+
name: "DDUF model falls back to default when platform does not support diffusers",
175+
backends: map[string]inference.Backend{
176+
"llamacpp": llamacppBackend,
177+
diffusers.Name: diffusersBackend,
178+
},
179+
defaultBackend: llamacppBackend,
180+
platform: mockPlatformSupport{diffusers: false},
181+
model: ddufModel,
182+
expectedBackend: "llamacpp",
183+
},
184+
{
185+
name: "DDUF model falls back to default when diffusers backend not registered",
186+
backends: map[string]inference.Backend{
187+
"llamacpp": llamacppBackend,
188+
},
189+
defaultBackend: llamacppBackend,
190+
platform: mockPlatformSupport{diffusers: true},
191+
model: ddufModel,
192+
expectedBackend: "llamacpp",
193+
},
194+
{
195+
name: "legacy diffusers format model selects diffusers backend",
196+
backends: map[string]inference.Backend{
197+
"llamacpp": llamacppBackend,
198+
diffusers.Name: diffusersBackend,
199+
},
200+
defaultBackend: llamacppBackend,
201+
platform: mockPlatformSupport{diffusers: true},
202+
model: legacyDiffusersModel,
203+
expectedBackend: diffusers.Name,
204+
},
156205
}
157206

158207
for _, tt := range tests {

0 commit comments

Comments
 (0)