feat: support concurrency in base sampling strategy #1175
Signed-off-by: Jake LoRocco <jake.lorocco@ibm.com> Assisted-by: CLAUDE:OPUS
planetf1
left a comment
A few cross-cutting observations alongside the line-level notes — mostly about behaviour shifts that users and plugin authors will hit but the diff itself doesn't surface. Happy to be told any of these are deliberate scope cuts for a follow-up.
Cost & rate limits. concurrency_budget=N multiplies the worst-case request count per sample() by N (and the expected count by some factor between 1 and N depending on success rate). For paid backends that's a real spend multiplier, and against rate-limited ones (Anthropic/OpenAI/Watsonx) we'll hit 429s sooner — possibly turning a previously-passing sample loop into a hard failure. The class docstring describes the mechanic but doesn't warn about either. Probably worth a one-line note in the Args: block, even if rate-limit handling is a separate piece of work.
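To make the multiplier concrete, here is a minimal sketch of the worst-case arithmetic. It assumes each sampling iteration issues exactly one generate call and that no subsample succeeds early; validators that call the backend themselves would add to the count. The helper name is hypothetical and not part of the PR.

```python
def worst_case_requests(loop_budget: int, concurrency_budget: int) -> int:
    """Upper bound on generate calls for one sample() call.

    Assumption: one generate call per sampling iteration, and every
    subsample runs to exhaustion (no early success).
    """
    if loop_budget <= 0 or concurrency_budget <= 0:
        raise ValueError("both budgets must be positive")
    return loop_budget * concurrency_budget


# Same loop_budget, different concurrency_budget: the bound scales linearly.
sequential = worst_case_requests(loop_budget=3, concurrency_budget=1)
concurrent = worst_case_requests(loop_budget=3, concurrency_budget=4)
print(sequential, concurrent)  # 3 12
```

Something along these lines in the Args: block ("worst case is loop_budget * concurrency_budget backend calls") would probably be enough.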
Determinism. Two identical sample() calls with concurrency_budget>1 can return different winning slices depending on which subsample's network call returns first. Existing qualitative tests and example notebooks that assume sampling stability may start to flake. A docstring note saying "selection is non-deterministic when concurrency_budget>1" would set expectations.
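A toy illustration of the effect, not the PR's code: two stand-in subsamples race, and the "winner" is decided purely by which one has lower latency. The function names and the use of asyncio.wait are assumptions made for the sketch.

```python
import asyncio


async def fake_subsample(index: int, latency: float) -> int:
    # Stand-in for one subsample's generate + validate round trip.
    await asyncio.sleep(latency)
    return index


async def first_winner(latencies: list[float]) -> int:
    tasks = [
        asyncio.create_task(fake_subsample(i, lat))
        for i, lat in enumerate(latencies)
    ]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    # Let the cancellations settle before returning.
    await asyncio.gather(*pending, return_exceptions=True)
    return next(iter(done)).result()


# Identical inputs, different network timing, different winning subsample.
print(asyncio.run(first_winner([0.01, 0.2])))  # 0
print(asyncio.run(first_winner([0.2, 0.01])))  # 1
```

With real backends the latencies are not under our control, which is why the selected slice can differ between otherwise identical calls.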
with_context(sampling_iteration=...) under concurrency. contextvars is the right primitive for this (survives create_task cleanly), and a quick read suggests it'll be fine, but I didn't see a test that pins the behaviour — i.e. that when subsample A is on iteration 3 and subsample B is on iteration 7, plugin handlers see the right iteration number for the slice they're processing. If this ever bleeds, every SAMPLING_ITERATION-driven plugin is wrong under concurrency. A small assertion test would lock it down.
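A rough shape for that assertion test, using a plain ContextVar as a stand-in for whatever with_context sets (I have not checked its signature, so none of these names are the project's API). Each task created with create_task runs in a copy of the caller's context, so a set() inside one subsample should not be visible in the other.

```python
import asyncio
import contextvars

# Stand-in for the value with_context(sampling_iteration=...) would carry.
sampling_iteration = contextvars.ContextVar("sampling_iteration", default=-1)


async def subsample(start: int, steps: int, seen: list) -> None:
    for expected in range(start, start + steps):
        sampling_iteration.set(expected)
        await asyncio.sleep(0)  # yield so the other subsample interleaves
        # What a plugin handler would observe for this slice.
        seen.append((expected, sampling_iteration.get()))


async def run_two_subsamples():
    seen_a, seen_b = [], []
    await asyncio.gather(
        asyncio.create_task(subsample(1, 3, seen_a)),
        asyncio.create_task(subsample(101, 3, seen_b)),
    )
    # The parent context is untouched by sets made inside the tasks.
    return seen_a, seen_b, sampling_iteration.get()


seen_a, seen_b, parent_value = asyncio.run(run_two_subsamples())
assert all(expected == observed for expected, observed in seen_a + seen_b)
assert parent_value == -1
```

The real test would swap the ContextVar for the hook machinery and assert on the iteration number each plugin handler receives.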
Cancellation cleanup. On early success the producers are cancelled mid-generate_from_context. Backend cleanup paths for in-flight HTTP / streaming response handles vary across providers. Not aware of an existing leak, but there's no test asserting "after early success, no in-flight backend calls remain". Could surface later as flaky test resource warnings under load.
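A sketch of the task-level half of such a test, with a long sleep standing in for an in-flight backend call. It only checks that no asyncio task is left pending after cancel + gather; it says nothing about provider-side HTTP handles, which would need backend-specific assertions. All names are hypothetical.

```python
import asyncio


async def producer(started: asyncio.Event) -> None:
    started.set()
    await asyncio.sleep(3600)  # stand-in for an in-flight generate call


async def cancel_and_settle(n: int) -> list:
    events = [asyncio.Event() for _ in range(n)]
    tasks = [asyncio.create_task(producer(e)) for e in events]
    # Make sure every producer is actually mid-call before cancelling.
    await asyncio.gather(*(e.wait() for e in events))
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)
    # Anything still not done here would be a leaked task.
    return [task for task in tasks if not task.done()]


leaked = asyncio.run(cancel_and_settle(3))
assert leaked == []
```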
Hook ordering. SAMPLING_ITERATION events now interleave across subsamples — the iteration field is globally unique (subsample_index * iterations + i + 1, nice), but consumers that assumed monotonic ordering will misbehave. Worth a one-line note on SamplingIterationPayload.
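For the payload note, a small worked example of the numbering might help. This assumes i is the zero-based iteration within a subsample and iterations is the per-subsample loop budget; the function names are mine.

```python
def global_iteration(subsample_index: int, i: int, iterations: int) -> int:
    # Formula from the diff: unique across subsamples, 1-based.
    return subsample_index * iterations + i + 1


def decode(iteration: int, iterations: int) -> tuple[int, int]:
    # Inverse mapping: recover (subsample_index, i) from the global number.
    return divmod(iteration - 1, iterations)


iterations = 3
# One plausible arrival order when two subsamples take turns finishing.
arrival = [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
emitted = [global_iteration(s, i, iterations) for s, i in arrival]
print(emitted)  # [1, 4, 2, 5, 3, 6]
```

The values are unique but not monotonic in arrival order, so a consumer that needs per-subsample ordering can decode the number rather than rely on event order.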
Docs & examples. Couldn't find a worked example of concurrency_budget in docs/examples/ or a callout in docs/AGENTS_TEMPLATE.md. Given how visible the speedup will be once people find it, an example showing "here's how to use it, here's the cost trade-off" would prevent both under-use and footgun-misuse.
Test gaps that fall out of the line-level points. Three the suite would benefit from:
- Backend exception under concurrency_budget=1 (covers the failure mode in the gather/_DONE comment).
- MultiTurnStrategy with concurrency_budget>1 exhausting all attempts, asserting which slice select_from_failure actually picks (covers the ordering note).
- Cancellation cleanup: assert no producer tasks remain pending after early success.
None of this is blocking from my side; mostly flagging for visibility so we don't ship the concurrency knob without the matching guardrails.
| t.cancel() # No-op if already done / cancelled. | ||
| # Wait for cancellations to settle so we don't leak tasks. | ||
| await asyncio.gather(*producer_tasks, return_exceptions=True) |
One thing to flag on the cleanup path — if the backend raises inside generate_from_context, the exception gets absorbed by _producer's finally (which still puts _DONE) and then dropped by gather(return_exceptions=True). The consumer ends up with an empty slices and the user sees AssertionError: result index cannot be out of range from SamplingResult.__init__, with nothing pointing back at the real cause.
In the pre-PR sequential path the exception propagated directly. Wondering if it's worth re-raising the first non-cancellation exception from the gather when no slices made it through, just for the default concurrency_budget=1 case:
| await asyncio.gather(*producer_tasks, return_exceptions=True) | |
| _gathered = await asyncio.gather(*producer_tasks, return_exceptions=True) | |
| if not slices: | |
| _exc = next( | |
| (r for r in _gathered if isinstance(r, BaseException) and not isinstance(r, asyncio.CancelledError)), | |
| None, | |
| ) | |
| if _exc: | |
| raise RuntimeError("all sampling subsamples failed to produce a result") from _exc |
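A self-contained reproduction of the pattern, under the assumption that the producer/consumer shape is roughly a queue plus a _DONE sentinel as in the diff. The names here are simplified stand-ins, not the PR's code.

```python
import asyncio

_DONE = object()  # sentinel: "this producer has finished"


async def producer(queue: asyncio.Queue, fail: bool) -> None:
    try:
        if fail:
            raise ConnectionError("backend unavailable")
        await queue.put("slice")
    finally:
        # Runs even when the backend raised, so the consumer never hangs.
        await queue.put(_DONE)


async def collect(n: int, fail: bool) -> list:
    queue: asyncio.Queue = asyncio.Queue()
    tasks = [asyncio.create_task(producer(queue, fail)) for _ in range(n)]
    slices, finished = [], 0
    while finished < n:
        item = await queue.get()
        if item is _DONE:
            finished += 1
        else:
            slices.append(item)
    gathered = await asyncio.gather(*tasks, return_exceptions=True)
    if not slices:
        # CancelledError derives from BaseException (Python 3.8+), so
        # filtering on Exception skips cancellations.
        exc = next((r for r in gathered if isinstance(r, Exception)), None)
        if exc is not None:
            raise RuntimeError(
                "all sampling subsamples failed to produce a result"
            ) from exc
    return slices


print(asyncio.run(collect(2, fail=False)))  # ['slice', 'slice']
```

With fail=True the caller gets a RuntimeError whose __cause__ is the original ConnectionError, instead of an unrelated AssertionError further down.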
| @@ -187,205 +259,288 @@ async def sample( | |||
| ) | |||
| effective_loop_budget = start_payload.loop_budget | |||
One thing to flag: a SAMPLING_LOOP_START hook can return loop_budget=0 (or negative), and the post-hook value is taken as-is. total_possible_generations then collapses to 0, no slices are produced, and we land on the same opaque AssertionError as the swallowed-exception case. The constructor assert validates self.loop_budget but the hook bypasses that.
| effective_loop_budget = start_payload.loop_budget | |
| effective_loop_budget = start_payload.loop_budget | |
| assert effective_loop_budget > 0, ( | |
| "SAMPLING_LOOP_START hook returned loop_budget <= 0; refusing to run." | |
| ) |
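One possible variation on the suggestion, offered as a sketch only: assert statements are skipped when Python runs with -O, so an explicit check with ValueError would still fire in optimized runs. The helper name is hypothetical.

```python
def checked_loop_budget(hook_value: int) -> int:
    """Validate the loop budget returned by a SAMPLING_LOOP_START hook."""
    if hook_value <= 0:
        # Explicit raise instead of assert so the guard survives `python -O`.
        raise ValueError(
            f"SAMPLING_LOOP_START hook returned loop_budget={hook_value}; "
            "expected a positive integer."
        )
    return hook_value


effective_loop_budget = checked_loop_budget(3)
```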
| # Module-level counter used by the "every 5th call passes" requirement below. | ||
| _validation_counter = 0 |
The module-level _validation_counter survives across tests in the same worker, and the global reads make the dependency easy to miss when reading the test in isolation. Could be lifted into the test function as a nonlocal closure variable instead — the surrounding tests already use that pattern. Not load-bearing, just a small nudge for future-us.
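Roughly what I have in mind, as a sketch: a factory that keeps the counter in a closure so each test gets a fresh one. The factory name and the validator signature are assumptions; the real requirement's validation function may take different arguments.

```python
def make_every_nth_validator(n: int = 5):
    calls = 0  # lives in the closure, so nothing leaks between tests

    def validate(_output) -> bool:
        nonlocal calls
        calls += 1
        # Passes on every n-th call, fails otherwise.
        return calls % n == 0

    return validate


validate = make_every_nth_validator(5)
results = [validate("candidate") for _ in range(10)]
print(results.count(True))  # 2
```

A second call to make_every_nth_validator starts again from zero, which is the isolation the module-level counter does not give.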
| all_results=sampled_results, | ||
| all_validations=sampled_scores, | ||
| success=s_result.success, | ||
| iterations_used=len(slices), |
iterations_used previously meant "how many sequential generate/validate cycles ran" — under concurrency it's now the total slice count across all subsamples (so up to loop_budget * concurrency_budget). Both readings are reasonable, but the field name and SamplingLoopEndPayload's docstring haven't moved with it, so an existing plugin will silently start seeing different numbers. Either a rename (slices_observed?) or a one-line note on the payload would probably be enough.
| if progress_indicator is not None: | ||
| progress_indicator.close() | ||
| s_result = _get_sampling_result( |
slices arrives in queue order rather than per-subsample iteration order, so it's now interleaved across concurrent subsamples. Strategies whose select_from_failure returns -1 (e.g. MultiTurnStrategy) used to mean "the deepest-repaired turn"; with concurrency_budget>1 it's just whichever subsample finished last. Might be worth sorting by (subsample_index, iteration) here before handing off, or noting in MultiTurnStrategy.select_from_failure's docstring that the contract shifts under concurrency.
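A sketch of the sort, assuming each slice carries (or can be tagged with) its subsample index and iteration. The Slice record below is hypothetical; the PR's slice type may look different.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Slice:
    # Hypothetical fields, for illustration only.
    subsample_index: int
    iteration: int
    text: str


# Queue (arrival) order: two subsamples interleaved by network timing.
arrival = [
    Slice(0, 0, "a0"),
    Slice(1, 0, "b0"),
    Slice(1, 1, "b1"),
    Slice(0, 1, "a1"),
    Slice(1, 2, "b2"),
    Slice(0, 2, "a2"),
]

ordered = sorted(arrival, key=lambda s: (s.subsample_index, s.iteration))

# Index -1 before sorting: whichever subsample happened to finish last.
print(arrival[-1].text)  # a2
# Index -1 after sorting: always the final turn of the highest-index subsample.
print(ordered[-1].text)  # b2
```

Sorting makes the -1 pick reproducible across runs, though it is still one subsample's final turn rather than a global "best", so the docstring note may be worth having either way.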
Pull Request
Issue
Fixes N/A; builds on previously closed PR (#240)
Description
Adds concurrency to our base sampling strategy. Previously we had no way to sample concurrently and could only sample iteratively. sample() now manages distinct generators, and subsample_iteration replaces the actual sampling code that was previously in sample().
The total number of sampling iterations is now loop_budget * concurrency_budget. concurrency_budget is the breadth of the tree and loop_budget is the depth. At any time, at most concurrency_budget sampling iterations are running.
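The breadth/depth relationship can be illustrated with a toy model. This is a sketch of the shape only, not the implementation in this PR: it has no early exit on success, and the function names are made up for the example.

```python
import asyncio


async def run_tree(loop_budget: int, concurrency_budget: int) -> tuple[int, int]:
    active = 0  # sampling iterations currently in flight
    peak = 0    # highest value `active` ever reached
    total = 0   # sampling iterations completed

    async def subsample() -> None:
        nonlocal active, peak, total
        for _ in range(loop_budget):  # depth of the tree
            active += 1
            peak = max(peak, active)
            await asyncio.sleep(0)  # stand-in for generate + validate
            active -= 1
            total += 1

    # Breadth of the tree: one coroutine per unit of concurrency_budget.
    await asyncio.gather(*(subsample() for _ in range(concurrency_budget)))
    return total, peak


total, peak = asyncio.run(run_tree(loop_budget=4, concurrency_budget=3))
print(total, peak)  # 12 3
```

Each subsample runs its iterations one after another, so the number in flight never exceeds concurrency_budget, while the total is the product of the two budgets.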
Testing
Attribution
Adding a new component, requirement, sampling strategy, or tool?
If your PR adds or modifies one of the types below, check the matching box. A checklist of type-specific review items will be posted as a comment.
NOTE: Please ensure you have an issue that has been acknowledged by a core contributor and routed you to open a pull request against this repository. Otherwise, please open an issue before continuing with this pull request.