feat(tinker): Add support for built-in loss functions and checkpoint control #523
Closed
Conversation
Collaborator
Looks like some checks aren't passing. And have we tested this?
Contributor
Author
No, I tested some of the more foundational pieces in earlier PRs (e.g. the move from …). I can wait to merge until that run is working, probably. (This will likely turn into a mega-PR with lots of unrelated changes along the way, though, as I find more pieces I need to update.)
added 4 commits
January 21, 2026 19:22
…control

Add two new features to TinkerBackend:

1. Built-in loss functions (tinker_loss_fn, tinker_loss_fn_config)
   - Supports Tinker's optimized losses: importance_sampling, ppo, cispo, dro
   - Uses forward_backward_async instead of forward_backward_custom_async
   - ~1.5x fewer FLOPs, up to 3x faster (per Tinker docs)
   - Default behavior unchanged (uses ART's custom loss)

2. Checkpoint control (save_checkpoint parameter)
   - When False, only saves sampler weights (fast, for inference)
   - When True (default), saves full state + optimizer (for resumption)
   - Enables faster training when full checkpoints are only needed at intervals

Both features are backwards-compatible - existing code works unchanged.
…ain()

- Add optional adam_beta1, adam_beta2, adam_eps parameters to train()
- Pass through to TinkerService via dev_config
- Use params when calling optim_step_async with tinker.AdamParams

This allows customization of Adam optimizer hyperparameters, which is needed when using non-default values (e.g., beta2=0.95 instead of 0.999).
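For illustration, a minimal sketch of the pass-through this commit describes. `dev_config`, `optim_step_async`, and `tinker.AdamParams` are all named in the commit; the exact field names on `AdamParams` are assumptions based on standard Adam naming, not confirmed by this PR:

```python
import tinker  # Tinker SDK; AdamParams is referenced by this commit

# Hypothetical sketch: read the optional hyperparameters out of dev_config
# and forward them to the optimizer step. Field names are assumptions.
adam_params = tinker.AdamParams(
    learning_rate=config.learning_rate,
    beta1=dev_config.get("adam_beta1", 0.9),
    beta2=dev_config.get("adam_beta2", 0.999),  # e.g. override to 0.95
    eps=dev_config.get("adam_eps", 1e-8),
)
await training_client.optim_step_async(adam_params)
```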
Add adam_beta1, adam_beta2, and adam_eps to fix Pyright type errors when assigning these keys to the dev_config dict.
added 3 commits
January 22, 2026 00:33
- Update shift_tensor to support both 1D and 2D tensors
- Replace NaN values in logprobs before JSON serialization to the Tinker API
- Guard the Qwen3InstructRenderer patch for older tinker_cookbook versions
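Roughly what the first two fixes could look like in isolation. The `shift_tensor` name comes from this commit, but its body here is an illustrative assumption:

```python
import torch

def shift_tensor(t: torch.Tensor, fill_value: float = 0.0) -> torch.Tensor:
    """Shift values one position left along the last dim, so position i
    holds the value for token i+1. Supports 1D and 2D tensors (sketch)."""
    shifted = torch.full_like(t, fill_value)
    if t.dim() == 1:
        shifted[:-1] = t[1:]
    elif t.dim() == 2:
        shifted[:, :-1] = t[:, 1:]
    else:
        raise ValueError(f"expected 1D or 2D tensor, got {t.dim()}D")
    return shifted

# NaN is not valid JSON; replace NaNs before serializing logprobs:
logprobs = torch.tensor([-0.5, float("nan"), -1.2])
logprobs = torch.nan_to_num(logprobs, nan=0.0)
```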
Previously, if port 8000 was already in use, the server would bind to a different port via get_free_port() but the client would still try to connect to port 8000, causing connection failures. Now the port is determined once upfront and passed to both the server and client.
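A sketch of the fix this commit describes. `get_free_port()` is named in the commit and shown here with a standard bind-to-port-0 implementation; `start_server` and `connect_client` are hypothetical helpers standing in for the actual server/client setup:

```python
import socket

def get_free_port() -> int:
    """Ask the OS for an unused port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

# Determine the port once, then hand the same value to both sides,
# so the client never assumes a hard-coded 8000.
port = get_free_port()
server = start_server(port=port)                               # hypothetical helper
client = connect_client(base_url=f"http://127.0.0.1:{port}")   # hypothetical helper
```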
Contributor
Author
replaced with #532
Summary
This PR adds two new features to `TinkerBackend` that are fully backwards-compatible with existing code.

1. Built-in Loss Functions
Adds support for Tinker's optimized built-in loss functions via new parameters on `TinkerBackend.train()`:

- `tinker_loss_fn`: Select from `"importance_sampling"`, `"ppo"`, `"cispo"`, `"dro"`
- `tinker_loss_fn_config`: Pass loss-specific config (e.g., `{"clip_low_threshold": 0.0, "clip_high_threshold": 6.0}`)

Benefits:
- Uses `forward_backward_async` instead of `forward_backward_custom_async`
- ~1.5x fewer FLOPs, up to 3x faster (per Tinker docs)

Default behavior is unchanged: when `tinker_loss_fn=None` (the default), training continues to use ART's custom loss implementation.
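For orientation, a minimal sketch (not the PR's actual code) of how the service-side dispatch between the two Tinker calls could look. The method names `forward_backward_async` and `forward_backward_custom_async` come from this PR; argument names on those calls are assumptions:

```python
# Sketch only: dispatch between Tinker's built-in losses and ART's custom
# loss. Argument names on the Tinker client calls are assumptions.
async def forward_backward(service, examples, tinker_loss_fn=None, tinker_loss_fn_config=None):
    if tinker_loss_fn is not None:
        # Built-in path: loss computed server-side by Tinker
        # ("importance_sampling", "ppo", "cispo", or "dro").
        return await service.training_client.forward_backward_async(
            examples,
            loss_fn=tinker_loss_fn,
            loss_fn_config=tinker_loss_fn_config or {},
        )
    # Default path (tinker_loss_fn=None): ART's custom loss, as before.
    return await service.training_client.forward_backward_custom_async(
        examples,
        loss_fn=service.compute_custom_loss,  # hypothetical callable
    )
```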
2. Checkpoint Control

The existing `save_checkpoint` parameter now controls checkpoint behavior in TinkerBackend:

- `save_checkpoint=True` (default): Saves full state + optimizer (enables training resumption)
- `save_checkpoint=False`: Only saves sampler weights (fast, for inference only)

This enables faster training when full checkpoints are only needed at specific intervals (e.g., at eval steps).
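A minimal sketch of the branch this implies inside the Tinker service. `_save_sampler_weights_only()` is named in this PR's Files Changed list; the Tinker client method name for the full save is an assumption:

```python
# Sketch only: the branch save_checkpoint implies inside the Tinker service.
# _save_sampler_weights_only is named in this PR; the client method name
# below is an assumption.
async def save(service, step: int, save_checkpoint: bool = True) -> None:
    if save_checkpoint:
        # Full state + optimizer: slower, but training can resume from it.
        await service.training_client.save_state_async(name=f"{step:04d}")
    else:
        # Sampler weights only: fast path, enough for inference/rollouts.
        await service._save_sampler_weights_only()
```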
Usage
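The original usage snippet did not survive extraction; below is a hedged sketch of what the new parameters might look like in a call, with model/backend setup elided. The parameter names come from this PR, but the `_config=art.dev.TrainConfig(...)` shape is an assumption based on the Files Changed list:

```python
import art

# Sketch only: parameter names come from this PR; the surrounding
# model.train(...) API shape is assumed, not taken from the PR.

# Train with one of Tinker's built-in losses instead of ART's custom loss.
await model.train(
    trajectory_groups,
    config=art.TrainConfig(learning_rate=1e-5),
    _config=art.dev.TrainConfig(
        tinker_loss_fn="cispo",
        tinker_loss_fn_config={
            "clip_low_threshold": 0.0,
            "clip_high_threshold": 6.0,
        },
    ),
)

# Save a full (resumable) checkpoint only on eval steps; otherwise save
# sampler weights only, which is fast and sufficient for inference.
await model.train(trajectory_groups, save_checkpoint=(step % eval_interval == 0))
```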
Files Changed
- `src/art/dev/train.py`: Added `tinker_loss_fn`, `tinker_loss_fn_config`, `tinker_save_checkpoint` to TrainConfig
- `src/art/tinker/backend.py`: Overrode `train()` with new parameters
- `src/art/tinker/service.py`: Added dispatch logic for built-in vs. custom loss; added `_save_sampler_weights_only()` method

Backwards Compatibility
All existing code continues to work unchanged. The new parameters are optional with sensible defaults that preserve current behavior.