Skip to content

Commit 0552978

Browse files
committed
Add /set command and context length configuration
Add a /set command to interactive mode so that users can configure parameters such as num_ctx at runtime. Also add environment variable support for the default context length (DMR_CONTEXT_LENGTH), and ensure the scheduler respects both settings when configuring backend runners.

Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 4d4402b commit 0552978

4 files changed

Lines changed: 67 additions & 10 deletions

File tree

cmd/cli/commands/run.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ import (
88
"io"
99
"os"
1010
"os/signal"
11+
"strconv"
1112
"strings"
1213
"syscall"
1314

1415
"github.com/charmbracelet/glamour"
1516
"github.com/docker/model-runner/cmd/cli/commands/completion"
1617
"github.com/docker/model-runner/cmd/cli/desktop"
1718
"github.com/docker/model-runner/cmd/cli/readline"
19+
"github.com/docker/model-runner/pkg/inference"
20+
"github.com/docker/model-runner/pkg/inference/scheduling"
1821
"github.com/fatih/color"
1922
"github.com/muesli/termenv"
2023
"github.com/spf13/cobra"
@@ -92,6 +95,7 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
9295
fmt.Fprintln(os.Stderr, "Available Commands:")
9396
fmt.Fprintln(os.Stderr, " /set system Set or update the system message")
9497
fmt.Fprintln(os.Stderr, " /bye Exit")
98+
fmt.Fprintln(os.Stderr, " /set Set a parameter (e.g., /set parameter num_ctx 4096)")
9599
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
96100
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
97101
fmt.Fprintln(os.Stderr, " /? files Help for file inclusion with @ symbol")
@@ -231,6 +235,36 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
231235
continue
232236
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
233237
return nil
238+
case strings.HasPrefix(line, "/set"):
239+
args := strings.Fields(line)
240+
if len(args) < 4 || args[1] != "parameter" {
241+
fmt.Fprintln(os.Stderr, "Usage: /set parameter <name> <value>")
242+
fmt.Fprintln(os.Stderr, "Available parameters: num_ctx")
243+
continue
244+
}
245+
paramName, paramValue := args[2], args[3]
246+
switch paramName {
247+
case "num_ctx":
248+
if val, err := strconv.ParseInt(paramValue, 10, 32); err == nil && val > 0 {
249+
ctx := int32(val)
250+
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
251+
Model: model,
252+
BackendConfiguration: inference.BackendConfiguration{
253+
ContextSize: &ctx,
254+
},
255+
}); err != nil {
256+
fmt.Fprintf(os.Stderr, "Failed to set num_ctx: %v\n", err)
257+
} else {
258+
fmt.Fprintf(os.Stderr, "Set num_ctx to %d\n", val)
259+
}
260+
} else {
261+
fmt.Fprintf(os.Stderr, "Invalid value for num_ctx: %s (must be a positive integer)\n", paramValue)
262+
}
263+
default:
264+
fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName)
265+
fmt.Fprintln(os.Stderr, "Available parameters: num_ctx")
266+
}
267+
continue
234268
case strings.HasPrefix(line, "/"):
235269
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
236270
continue

main.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"os"
88
"os/signal"
99
"path/filepath"
10+
"strconv"
1011
"strings"
1112
"syscall"
1213
"time"
@@ -77,6 +78,18 @@ func main() {
7778
sglangServerPath := os.Getenv("SGLANG_SERVER_PATH")
7879
mlxServerPath := os.Getenv("MLX_SERVER_PATH")
7980

81+
// Parse default context length from environment
82+
var defaultContextLength *int32
83+
if ctxStr := os.Getenv("DMR_CONTEXT_LENGTH"); ctxStr != "" {
84+
if parsed, err := strconv.ParseInt(ctxStr, 10, 32); err == nil && parsed > 0 {
85+
ctx := int32(parsed)
86+
defaultContextLength = &ctx
87+
log.Infof("DMR_CONTEXT_LENGTH: %d", ctx)
88+
} else {
89+
log.Warnf("Invalid DMR_CONTEXT_LENGTH: %s (must be a positive integer)", ctxStr)
90+
}
91+
}
92+
8093
// Create a proxy-aware HTTP transport
8194
// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
8295
var baseTransport *http.Transport
@@ -175,6 +188,7 @@ func main() {
175188
"",
176189
false,
177190
),
191+
defaultContextLength,
178192
)
179193

180194
// Create the HTTP handler for the scheduler

pkg/inference/scheduling/scheduler.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type Scheduler struct {
4040
tracker *metrics.Tracker
4141
// openAIRecorder is used to record OpenAI API inference requests and responses.
4242
openAIRecorder *metrics.OpenAIRecorder
43+
// defaultContextLength is the default context length from environment variable.
44+
defaultContextLength *int32
4345
}
4446

4547
// NewScheduler creates a new inference scheduler.
@@ -50,19 +52,21 @@ func NewScheduler(
5052
modelManager *models.Manager,
5153
httpClient *http.Client,
5254
tracker *metrics.Tracker,
55+
defaultContextLength *int32,
5356
) *Scheduler {
5457
openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)
5558

5659
// Create the scheduler.
5760
s := &Scheduler{
58-
log: log,
59-
backends: backends,
60-
defaultBackend: defaultBackend,
61-
modelManager: modelManager,
62-
installer: newInstaller(log, backends, httpClient),
63-
loader: newLoader(log, backends, modelManager, openAIRecorder),
64-
tracker: tracker,
65-
openAIRecorder: openAIRecorder,
61+
log: log,
62+
backends: backends,
63+
defaultBackend: defaultBackend,
64+
modelManager: modelManager,
65+
installer: newInstaller(log, backends, httpClient),
66+
loader: newLoader(log, backends, modelManager, openAIRecorder),
67+
tracker: tracker,
68+
openAIRecorder: openAIRecorder,
69+
defaultContextLength: defaultContextLength,
6670
}
6771

6872
// Scheduler successfully initialized.
@@ -253,7 +257,12 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe
253257

254258
// Build runner configuration with shared settings
255259
var runnerConfig inference.BackendConfiguration
256-
runnerConfig.ContextSize = req.ContextSize
260+
// Use request context size if provided, otherwise fall back to default from env var
261+
if req.ContextSize != nil {
262+
runnerConfig.ContextSize = req.ContextSize
263+
} else if s.defaultContextLength != nil {
264+
runnerConfig.ContextSize = s.defaultContextLength
265+
}
257266
runnerConfig.Speculative = req.Speculative
258267
runnerConfig.RuntimeFlags = runtimeFlags
259268

pkg/inference/scheduling/scheduler_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestCors(t *testing.T) {
3333
discard := logrus.New()
3434
discard.SetOutput(io.Discard)
3535
log := logrus.NewEntry(discard)
36-
s := NewScheduler(log, nil, nil, nil, nil, nil)
36+
s := NewScheduler(log, nil, nil, nil, nil, nil, nil)
3737
httpHandler := NewHTTPHandler(s, nil, []string{"*"})
3838
req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody)
3939
req.Header.Set("Origin", "docker.com")

0 commit comments

Comments
 (0)