Conversation
Code Review
This pull request implements a centralized metrics system using Ray's utility metrics to monitor HTTP requests, task queues, and server resources across the gateway, model, processor, and sampler components. Feedback focuses on preventing metric cardinality explosion in the middleware by using route templates, adhering to PEP 8 by moving inline imports to the top level, and refining type hints and function signatures for better maintainability and type safety.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
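The cardinality concern can be addressed by labeling request metrics with the matched route template (e.g. `/models/{model_id}`) rather than the raw URL path. In FastAPI the template is available from the matched route object; the sketch below illustrates the idea in plain Python with a hypothetical `ROUTE_TEMPLATES` table and `route_label()` helper (neither is part of the PR):

```python
import re

# Hypothetical route table; in the real middleware the templates would
# come from the FastAPI app's route objects.
ROUTE_TEMPLATES = ["/models/{model_id}", "/sessions/{session_id}/futures"]


def _compile(template: str) -> re.Pattern:
    # Turn "/models/{model_id}" into the regex "^/models/[^/]+$".
    pattern = re.sub(r"\{[^/}]+\}", "[^/]+", template)
    return re.compile(f"^{pattern}$")


_COMPILED = [(t, _compile(t)) for t in ROUTE_TEMPLATES]


def route_label(path: str) -> str:
    """Return the route template to use as a metric label.

    Labeling with the template keeps cardinality bounded by the number
    of registered routes; labeling with the raw path would create one
    label value per unique URL (session id, model id, ...) and grow
    without bound.
    """
    for template, matcher in _COMPILED:
        if matcher.match(path):
            return template
    return "unmatched"
```

With this, both `/models/abc` and `/models/xyz` are counted under the single label `/models/{model_id}`.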
Pull request overview
This PR adds Ray Serve–compatible observability (HTTP/task-queue/resource metrics) and extends Tinker/Twinkle training flows to support DPO-style reference logprobs, alongside several operational/cookbook updates.
Changes:
- Introduce a centralized `ray.util.metrics`-based metrics module and wire it into the Gateway/Model/Sampler/Processor apps, plus task queue and rate limiter instrumentation.
- Extend Tinker forward/forward_backward plumbing to support DPO (`ref_logps` extraction and `ref_outputs` propagation) and adjust backend collect behavior.
- Update Ray launcher initialization behavior and add/refresh cookbook configs & scripts (including DPO client examples).
Reviewed changes
Copilot reviewed 26 out of 26 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| src/twinkle/server/utils/task_queue.py | Adds task-queue metrics (queue depth/wait, execution time, task status counts) and passes gauges into the rate limiter. |
| src/twinkle/server/utils/state/server_state.py | Adds a background loop to periodically publish resource gauges (sessions/models/futures). |
| src/twinkle/server/utils/rate_limiter.py | Adds optional metrics gauge updates for active tokens tracked by the limiter. |
| src/twinkle/server/utils/metrics.py | New central metrics definitions + FastAPI middleware for request counters/latency histograms. |
| src/twinkle/server/sampler/app.py | Registers HTTP metrics middleware; sets task queue deployment label to Sampler. |
| src/twinkle/server/processor/app.py | Registers HTTP metrics middleware for Processor. |
| src/twinkle/server/model/app.py | Registers HTTP metrics middleware; sets task queue deployment label to Model. |
| src/twinkle/server/model/tinker_handlers.py | Adjusts template selection for Qwen3.5 and changes forward path to use updated backend return shape. |
| src/twinkle/server/model/backends/transformers_model.py | Refactors Tinker forward paths and updates Twinkle-native collect behavior for forward outputs. |
| src/twinkle/server/model/backends/megatron_model.py | Refactors Tinker forward paths and updates Twinkle-native collect behavior for forward outputs. |
| src/twinkle/server/model/backends/common.py | Adds shared helpers for Tinker loss setup/output building and ref_logps → ref_outputs conversion. |
| src/twinkle/server/common/datum.py | Extracts ref_logps from Datum loss inputs for DPO. |
| src/twinkle/model/megatron/multi_lora_megatron.py | Binds adapter_name into LoRA save converter via functools.partial. |
| src/twinkle/metric/dpo.py | Accepts non-tensor logps by converting to a tensor before alignment. |
| src/twinkle/loss/dpo.py | Accepts non-tensor ref_logps by converting to a tensor before alignment. |
| src/twinkle_client/utils/patch_tinker.py | Extends typing imports and introduces a new patch-state flag variable. |
| src/twinkle_client/common/serialize.py | Adds BaseModel serialization handling for client HTTP parameter serialization. |
| src/twinkle/server/launcher.py | Changes Ray initialization to attempt connecting to an existing cluster via address='auto'. |
| src/twinkle/server/gateway/server.py | Registers HTTP metrics middleware for Gateway. |
| src/twinkle/server/gateway/twinkle_gateway_handlers.py | Adds a /twinkle/status endpoint returning cleanup/resource counts. |
| pyproject.toml | Removes the upper bound on the datasets dependency. |
| cookbook/client/twinkle/self_host/dpo.py | Adds a Twinkle-native self-host DPO training example script. |
| cookbook/client/tinker/self_host/dpo.py | Adds a Tinker-compatible self-host DPO training example script. |
| cookbook/client/server/megatron/server_config.yaml | Minor YAML formatting tweak. |
| cookbook/client/server/megatron/server_config_4b.yaml | Updates sample deployment ports and various sizing/limit parameters. |
| cookbook/client/server/megatron/run.sh | Replaces the minimal launcher with a parameterized Ray+Prometheus+server startup script. |
```python
# Connect to existing cluster if available, otherwise start local instance
ray.init(
    address='auto',
    namespace=namespace,
    runtime_env=runtime_env,
)
logger.info(f'Ray initialized with namespace={namespace}')
```
ray.init(address='auto', ...) does not "start local instance" when no cluster is running; it raises an error if it can't connect. If this launcher is used in single-node mode, initialization will fail. Consider trying address='auto' first and falling back to a plain ray.init(namespace=..., runtime_env=...) (or making the address configurable) so local startup still works.
Suggested change:

```python
# Try to connect to an existing cluster first, and fall back to
# starting a local Ray instance when no cluster is available.
try:
    ray.init(
        address='auto',
        namespace=namespace,
        runtime_env=runtime_env,
    )
    logger.info(f'Connected to existing Ray cluster with namespace={namespace}')
except Exception as exc:
    logger.info(
        'Failed to connect to an existing Ray cluster with '
        f"address='auto'; starting a local Ray instance instead: {exc}")
    ray.init(
        namespace=namespace,
        runtime_env=runtime_env,
    )
    logger.info(f'Started local Ray instance with namespace={namespace}')
```
```python
@staticmethod
def _normalize_tensor_output(value):
    """Normalize various output formats (tensor, list of tensors, nested lists, floats) to a single tensor.

    Handles:
    - torch.Tensor: detach and move to cpu
    - list of torch.Tensor: cat along dim=0
    - nested lists: recursively flatten and cat
    - list of floats/int: convert to tensor
    """
    if value is None:
        return None

    if isinstance(value, torch.Tensor):
        return value.detach().cpu()

    if isinstance(value, list):
        return torch.as_tensor(value, dtype=torch.float32).detach().cpu()
```
_normalize_tensor_output() currently routes any list through torch.as_tensor(...). This will break when the model returns a list of tensors (common for Megatron/TP/PP outputs) because torch.as_tensor([tensor, ...]) fails or produces an object tensor. The previous logic handled list-of-tensors by concatenating; please restore support for list/tuple of tensors (and optionally nested lists) to avoid runtime failures in _tinker_build_output().
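A sketch of one possible fix, restoring the list-of-tensors and nested-list handling the docstring promises (this is an illustration, not the PR's code; the function name and placement are assumptions):

```python
import torch


def normalize_tensor_output(value):
    """Hypothetical fix sketch: normalize tensor, list-of-tensor,
    nested-list, and scalar-list outputs to a single CPU tensor."""
    if value is None:
        return None
    if isinstance(value, torch.Tensor):
        return value.detach().cpu()
    if isinstance(value, (list, tuple)):
        if len(value) == 0:
            return torch.empty(0)
        # Recurse so nested lists of tensors are flattened first.
        normalized = [normalize_tensor_output(v) for v in value]
        if all(v.dim() == 0 for v in normalized):
            # List of scalars: stack into a 1-D tensor.
            return torch.stack(normalized)
        # List of tensors (possibly mixed with scalars): flatten 0-d
        # entries and concatenate along dim=0, as the old logic did.
        return torch.cat(
            [v.reshape(-1) if v.dim() == 0 else v for v in normalized], dim=0)
    # Bare scalar (float/int).
    return torch.as_tensor(value, dtype=torch.float32)
```

The recursion keeps the scalar-list path (`[1.0, 2.0]` → 1-D tensor) while no longer routing `[tensor, tensor]` through `torch.as_tensor`, which is the failure mode the comment describes.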
```python
elif loss_fn == 'importance_sampling':
    has_ref_logps = any('ref_logps' in d.loss_fn_inputs for d in inputs)
    if has_ref_logps:
        beta = kwargs.pop('dpo_beta', 0.1)
        loss_type = kwargs.pop('dpo_loss_type', 'sigmoid')
        sft_weight = kwargs.pop('dpo_sft_weight', 0.0)
        self.set_loss(
            'DPOLoss', adapter_name=adapter_name, beta=beta, loss_type=loss_type, sft_weight=sft_weight)
        self.add_metric('DPOMetric', adapter_name=adapter_name, beta=beta)
    else:
```
_tinker_setup_loss() calls add_metric('DPOMetric', ...) every time loss_fn=='importance_sampling' with ref_logps. add_metric() appends to train_status.metrics, so this will accumulate duplicate DPOMetric instances over steps and can skew metric reporting / leak memory. Consider adding the metric only once per adapter (e.g., check existing metric types before appending) or making add_metric idempotent for this case.
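One way to make the registration idempotent is to check the already-tracked metric types before appending. The helper, stand-in classes, and attribute names below are hypothetical illustrations of that idea, not the PR's actual API:

```python
def add_metric_once(train_status, metric_name, factory, **kwargs):
    # Hypothetical helper: register a metric only if no metric of the
    # same type is already tracked, so repeated forward_backward steps
    # with ref_logps do not accumulate duplicate DPOMetric instances.
    if any(type(m).__name__ == metric_name for m in train_status.metrics):
        return
    train_status.metrics.append(factory(**kwargs))


# Minimal stand-ins to show usage (the real classes live in twinkle).
class DPOMetric:
    def __init__(self, beta=0.1):
        self.beta = beta


class TrainStatus:
    def __init__(self):
        self.metrics = []


status = TrainStatus()
add_metric_once(status, 'DPOMetric', DPOMetric, beta=0.1)
add_metric_once(status, 'DPOMetric', DPOMetric, beta=0.1)  # no-op
```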
```diff
 output, loss = self.model.tinker_forward_only(
     inputs=datum_list, adapter_name=adapter_name, **loss_fn_config)
 return types.ForwardBackwardOutput(
     loss_fn_output_type='CrossEntropyLossReturn',
     loss_fn_outputs=output,
-    metrics={'loss:sum': loss},
+    metrics={'loss:avg': loss},
 )
```
The /tinker/forward handler reports metrics={'loss:avg': loss}, but tinker_forward_only() now returns a hard-coded 0.0 loss in both Transformers and Megatron backends. This makes the returned loss metric misleading and can break clients that rely on it. Either compute the forward loss as before (via calculate_loss) or omit/rename the loss metric for forward-only requests.
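If the forward-only path keeps returning a placeholder loss, one option is to attach the metric conditionally. A minimal sketch of that choice (the helper name is hypothetical):

```python
def build_forward_metrics(loss):
    # Hypothetical helper: attach a loss metric only when the backend
    # actually computed one; reporting a hard-coded 0.0 as 'loss:avg'
    # would mislead clients of the forward-only endpoint.
    if loss is None:
        return {}
    return {'loss:avg': loss}
```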
```python
# Pydantic models: convert to dict for JSON serialization by requests
return obj.model_dump(mode='json')
```
serialize_object() is annotated to return str but returns a dict for Pydantic BaseModel. This function is also used by _serialize_params() for both JSON bodies and query params; returning a dict can lead to incorrect URL encoding / unexpected server-side parsing. Consider consistently returning JSON strings (e.g., json.dumps(model_dump(...))) or updating _serialize_params() to handle BaseModel differently for params vs JSON.
Suggested change:

```python
# Pydantic models: serialize to a JSON string to keep behavior
# consistent across JSON bodies and query params.
return json.dumps(obj.model_dump(mode='json'), ensure_ascii=False)
```
```bash
        echo "  ./run.sh --cpu-workers 4 --head ''   # CPU-only mode"
        exit 0
        ;;
    *)
        print_error "Unknown argument: $1"
        echo "Run with --help for usage"
        exit 1
        ;;
```
print_error is called in the argument-parsing loop before the function is defined. In bash, functions must be defined before they are invoked at runtime, so passing an unknown flag will error with print_error: command not found instead of printing the intended message. Define helper functions before the while [[ $# -gt 0 ]] loop (or inline the error printing) to make the error path work.
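The suggested ordering can be sketched as follows; option names and messages here are illustrative, not the script's actual flags:

```shell
#!/usr/bin/env bash
# Helpers are defined BEFORE the argument-parsing loop that calls
# them, so the unknown-flag error path actually works.
print_error() {
    echo "ERROR: $1" >&2
}

while [[ $# -gt 0 ]]; do
    case "$1" in
        --help)
            echo "Usage: ./run.sh [options]"
            exit 0
            ;;
        *)
            print_error "Unknown argument: $1"
            exit 1
            ;;
    esac
done
```

Bash resolves function names at call time, so a `print_error` defined after the loop simply does not exist yet when the loop runs.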