From 8ad7dd5f9d68e2df0390601c88ad046953a6db53 Mon Sep 17 00:00:00 2001 From: Annie Liang Date: Fri, 5 Jun 2026 09:31:43 -0700 Subject: [PATCH 1/4] fix(cosmos-spark): validate bounded change feed EOF against planned endLsn When TransientIOErrorsRetryingIterator reaches EOF on a bounded change feed read, validate that the latest continuation LSN (min across ranges) has reached the planned endLsn. If not, throw OperationCancelledException so the existing transient-IO retry path can recover or fail the task, preventing silent commits of unread bounded intervals. Fixes #49380 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../spark/ChangeFeedPartitionReader.scala | 36 +++++++---- .../TransientIOErrorsRetryingIterator.scala | 49 ++++++++++++++- ...ransientIOErrorsRetryingIteratorSpec.scala | 60 ++++++++++++++++++- 3 files changed, 129 insertions(+), 16 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala index 5d83a5139ef2..048240293fc6 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.types.StructType import java.util +import scala.util.control.NonFatal private object ChangeFeedPartitionReader { val LsnPropertyName: String = LsnAttributeName @@ -64,7 +65,7 @@ private case class ChangeFeedPartitionReader } private val containerTargetConfig = CosmosContainerConfig.parseCosmosContainerConfig(config) - log.logInfo(s"Reading from feed range ${partition.feedRange}, startLsn $getPartitionStartLsn, " + + log.logInfo(s"Reading from feed range ${partition.feedRange}, startLsn ${startLsn.map(_.toString).getOrElse("n/a")}, " + s"endLsn ${partition.endLsn} of " + s"container ${containerTargetConfig.database}.${containerTargetConfig.container}") private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config) @@ -184,19 +185,29 @@ private case class ChangeFeedPartitionReader } } - private def getPartitionStartLsn: Long = { - if (partition.continuationState.isDefined) { - SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(this.partition.continuationState.get) - } else { - 0 + private def getPartitionStartLsn: Option[Long] = { + partition.continuationState.map { continuationState => + try { + val continuationTokens = SparkBridgeImplementationInternal + .extractContinuationTokensFromChangeFeedStateJson(continuationState) + + if (continuationTokens.nonEmpty) { + continuationTokens.minBy(_._2)._2 + } else { + SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationState) + } + } catch { + case NonFatal(_) => + SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationState) + } } } private val changeFeedRequestOptions = { - val startLsn = getPartitionStartLsn + val requestStartLsn = startLsn.map(_.toString).getOrElse("n/a") log.logDebug( - s"Request options for Range '${partition.feedRange.min}-${partition.feedRange.max}' LSN '$startLsn'") + s"Request options for Range '${partition.feedRange.min}-${partition.feedRange.max}' LSN '$requestStartLsn'") val options = CosmosChangeFeedRequestOptions .createForProcessingFromContinuation(this.partition.continuationState.get) @@ -263,7 +274,8 @@ private case class ChangeFeedPartitionReader readConfig.maxItemCount, readConfig.prefetchBufferSize, operationContextAndListenerTuple, - this.partition.endLsn + this.partition.endLsn, + startLsn ) override def next(): Boolean = { @@ -299,11 +311,9 @@ private case class ChangeFeedPartitionReader case None => // for change feed, we would only reach here before the first page got fetched // fallback to use the continuation token from the partition instead - Some(SparkBridgeImplementationInternal - .extractContinuationTokensFromChangeFeedStateJson(partition.continuationState.get) - .minBy(_._2)._2) + startLsn } - if (latestLsnOpt.isDefined) latestLsnOpt.get - startLsn else 0 + latestLsnOpt.flatMap(latestLsn => startLsn.map(latestLsn - _)).getOrElse(0) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala index c949cd846f6c..14835b1e405b 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala @@ -11,7 +11,7 @@ import com.azure.cosmos.util.{CosmosPagedFlux, CosmosPagedIterable} import java.util.concurrent.{ExecutorService, SynchronousQueue, ThreadPoolExecutor, TimeUnit, TimeoutException} import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.util.Random -import scala.util.control.Breaks +import scala.util.control.{Breaks, NonFatal} import scala.concurrent.{Await, ExecutionContext, Future} import com.azure.cosmos.implementation.{ChangeFeedSparkRowItem, OperationCancelledException, SparkBridgeImplementationInternal} @@ -41,7 +41,8 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] val pageSize: Int, val pagePrefetchBufferSize: Int, val operationContextAndListener: Option[OperationContextAndListenerTuple], - val endLsn: Option[Long] + val endLsn: Option[Long], + val startLsn: Option[Long] = None ) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable { private[spark] var maxRetryIntervalInMs = CosmosConstants.maxRetryIntervalForTransientFailuresInMs @@ -177,6 +178,7 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] None } } else { + validateEofProgressOrThrow() Some(false) } } @@ -237,6 +239,49 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] } } + private[this] def validateEofProgressOrThrow(): Unit = { + endLsn.foreach { targetEndLsn => + val latestMinLsn = Option(lastContinuationToken.get()) + .map(extractLatestMinLsnFromContinuation) + + val isEofValid = latestMinLsn match { + case Some(observedLsn) => observedLsn >= targetEndLsn + case None => startLsn.contains(targetEndLsn) + } + + if (!isEofValid) { + val observedLsnText = latestMinLsn.map(_.toString).getOrElse("no page consumed") + val message = s"Bounded change feed read reached EOF before planned endLsn. " + + s"startLsn: $startLsn, endLsn: $targetEndLsn, observed minLatestLsn: $observedLsnText, " + + s"totalChangesCnt: ${totalChangesCnt.get()}, Context: $operationContextString" + + throw new OperationCancelledException(message, null) + } + } + } + + private[this] def extractLatestMinLsnFromContinuation(continuationToken: String): Long = { + try { + val continuationTokens = SparkBridgeImplementationInternal + .extractContinuationTokensFromChangeFeedStateJson(continuationToken) + + if (continuationTokens.nonEmpty) { + continuationTokens.map(_._2).min + } else { + SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationToken) + } + } catch { + case NonFatal(rangeEnumerationFailure) => + try { + SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationToken) + } catch { + case NonFatal(fallbackFailure) => + rangeEnumerationFailure.addSuppressed(fallbackFailure) + throw rangeEnumerationFailure + } + } + } + // Clean up iterator references - the underlying Reactor subscription // from Flux.toIterable() will be cleaned up when the iterator is GC'd override def close(): Unit = { diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala index b8400fdd3eff..302a71678385 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala @@ -3,7 +3,7 @@ package com.azure.cosmos.spark import com.azure.cosmos.CosmosException -import com.azure.cosmos.implementation.SparkRowItem +import com.azure.cosmos.implementation.{OperationCancelledException, SparkRowItem} import com.azure.cosmos.models.{FeedResponse, ModelBridgeInternal} import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait import com.azure.cosmos.util.UtilBridgeInternal @@ -13,6 +13,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode import reactor.core.publisher.Flux import java.time.Duration +import java.util.Base64 import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong @@ -180,6 +181,32 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr factoryCallCount.get shouldEqual 1 } + "Bounded change feed reads" should + "not complete when the feed ends before the planned end LSN" in { + + val endLsn = 20L + val lastReturnedLsn = 15L + + val response = generateFeedResponse("ChangeFeed", 1, -1) + ModelBridgeInternal.setFeedResponseContinuationToken( + changeFeedContinuation(lastReturnedLsn), + response + ) + + val iterator = new TransientIOErrorsRetryingIterator( + _ => UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(response)) + ), + pageSize, + 1, + None, + Some(endLsn) + ) + iterator.maxRetryCount = 0 + + intercept[OperationCancelledException](iterator.hasNext) + } + private val objectMapper = new ObjectMapper @throws[JsonProcessingException] @@ -340,6 +367,37 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr } } + private def changeFeedContinuation(lsn: Long): String = { + val state = + s"""{ + | "V": 1, + | "Rid": "testContainer", + | "Mode": "INCREMENTAL", + | "StartFrom": { + | "Type": "BEGINNING" + | }, + | "Continuation": { + | "V": 1, + | "Rid": "testContainer", + | "Continuation": [ + | { + | "token": "$lsn", + | "range": { + | "min": "", + | "max": "FF" + | } + | } + | ], + | "Range": { + | "min": "", + | "max": "FF" + | } + | } + |}""".stripMargin + + Base64.getEncoder.encodeToString(state.getBytes("UTF-8")) + } + private class DummyTransientCosmosException extends CosmosException(500, "Dummy Internal Server Error") From 1c24fd610862248aa31dc3908746ebbf5fcc0a37 Mon Sep 17 00:00:00 2001 From: Annie Liang Date: Fri, 5 Jun 2026 09:40:11 -0700 Subject: [PATCH 2/4] test(cosmos-spark): add bounded change feed EOF validation matrix Covers issue #49380 validation table rows 2-6, plus retry-path and row-level upper-bound regression checks for TransientIOErrorsRetryingIterator. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ...ransientIOErrorsRetryingIteratorSpec.scala | 228 +++++++++++++++++- 1 file changed, 227 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala index 302a71678385..b3167a0014af 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala @@ -3,13 +3,14 @@ package com.azure.cosmos.spark import com.azure.cosmos.CosmosException -import com.azure.cosmos.implementation.{OperationCancelledException, SparkRowItem} +import com.azure.cosmos.implementation.{ChangeFeedSparkRowItem, OperationCancelledException, SparkRowItem} import com.azure.cosmos.models.{FeedResponse, ModelBridgeInternal} import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait import com.azure.cosmos.util.UtilBridgeInternal import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.sql.Row import reactor.core.publisher.Flux import java.time.Duration @@ -207,6 +208,189 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr intercept[OperationCancelledException](iterator.hasNext) } + "Bounded change feed reads" should + "complete cleanly when the final continuation reaches the planned end LSN" in { + + // Validation matrix row #2: startLsn=10, endLsn=20, single-range continuation=20 -> complete. + val endLsn = 20L + val lastReturnedLsn = 20L + + val response = generateFeedResponse("ChangeFeed", 1, -1) + ModelBridgeInternal.setFeedResponseContinuationToken( + changeFeedContinuation(lastReturnedLsn), + response + ) + + val iterator = new TransientIOErrorsRetryingIterator( + _ => UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(response)) + ), + pageSize, + 1, + None, + Some(endLsn), + Some(10L) + ) + iterator.maxRetryCount = 0 + + iterator.hasNext shouldEqual false + } + + "Bounded change feed reads" should + "complete cleanly when no page is consumed and startLsn already equals endLsn" in { + + // Validation matrix row #3: startLsn=20, endLsn=20, empty flux -> complete. + val endLsn = 20L + + val iterator = new TransientIOErrorsRetryingIterator[SparkRowItem]( + _ => UtilBridgeInternal.createCosmosPagedFlux(_ => Flux.empty()), + pageSize, + 1, + None, + Some(endLsn), + Some(endLsn) + ) + iterator.maxRetryCount = 0 + + iterator.hasNext shouldEqual false + } + + "Bounded change feed reads" should + "throw when no page is consumed and startLsn is below endLsn" in { + + // Validation matrix row #4: startLsn=10, endLsn=20, empty flux -> throws. + val endLsn = 20L + + val iterator = new TransientIOErrorsRetryingIterator[SparkRowItem]( + _ => UtilBridgeInternal.createCosmosPagedFlux(_ => Flux.empty()), + pageSize, + 1, + None, + Some(endLsn), + Some(10L) + ) + iterator.maxRetryCount = 0 + + intercept[OperationCancelledException](iterator.hasNext) + } + + "Bounded change feed reads" should + "throw when any range in a multi-range continuation lags behind the planned end LSN" in { + + // Validation matrix row #5: continuation [20, 18] (min=18) < endLsn=20 -> throws. + val endLsn = 20L + + val response = generateFeedResponse("ChangeFeed", 1, -1) + ModelBridgeInternal.setFeedResponseContinuationToken( + multiRangeChangeFeedContinuation(Seq(20L, 18L)), + response + ) + + val iterator = new TransientIOErrorsRetryingIterator( + _ => UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(response)) + ), + pageSize, + 1, + None, + Some(endLsn), + Some(10L) + ) + iterator.maxRetryCount = 0 + + intercept[OperationCancelledException](iterator.hasNext) + } + + "Unbounded change feed reads" should + "complete cleanly at EOF without any LSN progress validation" in { + + // Validation matrix row #6: endLsn=None (unbounded), empty flux -> no throw. + // Guards against regression of validation kicking in for batch/unbounded mode. + val iterator = new TransientIOErrorsRetryingIterator[SparkRowItem]( + _ => UtilBridgeInternal.createCosmosPagedFlux(_ => Flux.empty()), + pageSize, + 1, + None, + None, + None + ) + iterator.maxRetryCount = 0 + + iterator.hasNext shouldEqual false + } + + "Bounded change feed reads" should + "retry the bounded EOF failure up to maxRetryCount before propagating it" in { + + // Validation matrix row #7: under-run EOF is treated as a transient (408) failure + // and re-subscribes the underlying flux factory until retries are exhausted. + val endLsn = 20L + val lastReturnedLsn = 15L + val maxRetryCount = 2 + val factoryCallCount = new AtomicLong(0) + + val iterator = new TransientIOErrorsRetryingIterator( + _ => { + factoryCallCount.incrementAndGet() + val response = generateFeedResponse("ChangeFeed", 1, -1) + ModelBridgeInternal.setFeedResponseContinuationToken( + changeFeedContinuation(lastReturnedLsn), + response + ) + UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(response)) + ) + }, + pageSize, + 1, + None, + Some(endLsn), + Some(10L) + ) + iterator.maxRetryCount = maxRetryCount + iterator.maxRetryIntervalInMs = 1 + + intercept[OperationCancelledException](iterator.hasNext) + + // 1 initial attempt + maxRetryCount retries + factoryCallCount.get shouldEqual (1 + maxRetryCount) + } + + "Bounded change feed reads" should + "still suppress rows above endLsn while passing the EOF progress check at the boundary" in { + + // Validation matrix row #8: page contains a row with _lsn > endLsn and the + // final continuation reaches endLsn exactly. validateNextLsn must continue to + // suppress the over-LSN row, and validateEofProgressOrThrow must accept the + // boundary continuation without throwing. + val endLsn = 20L + val rowAboveEndLsn = ChangeFeedSparkRowItem(Row.empty, None, "25") + + val response: FeedResponse[ChangeFeedSparkRowItem] = ModelBridgeInternal + .createFeedResponse( + java.util.Collections.singletonList(rowAboveEndLsn), + new ConcurrentHashMap[String, String] + ) + ModelBridgeInternal.setFeedResponseContinuationToken( + changeFeedContinuation(endLsn), + response + ) + + val iterator = new TransientIOErrorsRetryingIterator[ChangeFeedSparkRowItem]( + _ => UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(response)) + ), + pageSize, + 1, + None, + Some(endLsn), + Some(10L) + ) + iterator.maxRetryCount = 0 + + iterator.hasNext shouldEqual false + } + private val objectMapper = new ObjectMapper @throws[JsonProcessingException] @@ -398,6 +582,48 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr Base64.getEncoder.encodeToString(state.getBytes("UTF-8")) } + private def multiRangeChangeFeedContinuation(lsns: Seq[Long]): String = { + // Splits the [""..."FF"] range into evenly-sized adjacent sub-ranges, one per supplied LSN. + // The boundaries are arbitrary hex strings; the iterator only inspects the per-range tokens. + require(lsns.nonEmpty, "lsns must contain at least one value") + val boundaries: Seq[String] = "" +: + (1 until lsns.size).map(i => f"${(0xFF * i) / lsns.size}%02X") :+ + "FF" + + val ranges = lsns.zipWithIndex.map { case (lsn, i) => + s"""{ + | "token": "$lsn", + | "range": { + | "min": "${boundaries(i)}", + | "max": "${boundaries(i + 1)}" + | } + |}""".stripMargin + }.mkString(",\n") + + val state = + s"""{ + | "V": 1, + | "Rid": "testContainer", + | "Mode": "INCREMENTAL", + | "StartFrom": { + | "Type": "BEGINNING" + | }, + | "Continuation": { + | "V": 1, + | "Rid": "testContainer", + | "Continuation": [ + | $ranges + | ], + | "Range": { + | "min": "", + | "max": "FF" + | } + | } + |}""".stripMargin + + Base64.getEncoder.encodeToString(state.getBytes("UTF-8")) + } + private class DummyTransientCosmosException extends CosmosException(500, "Dummy Internal Server Error") From 5654b5fddce98f3218a3355dab22db411970e4e9 Mon Sep 17 00:00:00 2001 From: Annie Liang Date: Fri, 5 Jun 2026 10:04:25 -0700 Subject: [PATCH 3/4] =?UTF-8?q?fix(cosmos-spark):=20address=20review=20ite?= =?UTF-8?q?r=201=20=E2=80=94=20parse-failure=20retry=20+=20helper=20consol?= =?UTF-8?q?idation=20+=20diagnostics?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - F1: catch NonFatal in extractLatestMinLsnFromContinuation, wrap as retryable OperationCancelledException - F2: drop startLsn default; require callers to pass it explicitly - F3: consolidate min-LSN-across-ranges parsing into SparkBridgeImplementationInternal - F4: logWarning on first under-run with planned vs observed LSN - F6: tighten getLatestContinuationToken to Option(lastContinuationToken.get()) - F7: logError before throw + clearer diagnostic context - F9: render startLsn/endLsn without raw Option wrapper in exception text - S2: verified continuation includes both child ranges after split (see code comment) - S7: add cancellation-hint to exception message - F8/S8 test additions: malformed continuation, multi-page progression, degenerate endLsn==startLsn with continuation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../SparkBridgeImplementationInternal.scala | 24 ++++ .../spark/ChangeFeedPartitionReader.scala | 18 +-- .../cosmos/spark/ItemsPartitionReader.scala | 1 + .../TransientIOErrorsRetryingIterator.scala | 61 +++++----- ...ansientIOErrorsRetryingIteratorITest.scala | 2 + ...ransientIOErrorsRetryingIteratorSpec.scala | 108 +++++++++++++++++- 6 files changed, 167 insertions(+), 47 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/implementation/SparkBridgeImplementationInternal.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/implementation/SparkBridgeImplementationInternal.scala index f0cb2d8004da..c947b7801392 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/implementation/SparkBridgeImplementationInternal.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/implementation/SparkBridgeImplementationInternal.scala @@ -17,6 +17,7 @@ import com.fasterxml.jackson.databind.ObjectMapper import scala.collection.convert.ImplicitConversions.`list asScalaBuffer` import scala.collection.mutable +import scala.util.control.NonFatal // scalastyle:off underscore.import import com.azure.cosmos.implementation.feedranges._ @@ -147,6 +148,29 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTra .toArray } + def extractMinLatestLsnFromChangeFeedContinuationOrFallback(continuation: String): Long = { + try { + val continuationTokens = extractContinuationTokensFromChangeFeedStateJson(continuation) + + if (continuationTokens.nonEmpty) { + // FeedRangeContinuation.handleFeedRangeGone expands split ranges by adding child tokens, so the + // minimum across the current token set represents the planned feed range after split handling. + continuationTokens.map(_._2).min + } else { + extractLsnFromChangeFeedContinuation(continuation) + } + } catch { + case NonFatal(rangeEnumerationFailure) => + try { + extractLsnFromChangeFeedContinuation(continuation) + } catch { + case NonFatal(fallbackFailure) => + rangeEnumerationFailure.addSuppressed(fallbackFailure) + throw rangeEnumerationFailure + } + } + } + private[cosmos] def rangeToNormalizedRange(rangeInput: Range[String]) = { val range = FeedRangeInternal.normalizeRange(rangeInput) assert(range != null, "Argument 'range' must not be null.") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala index 048240293fc6..d8cf4b415bcd 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ChangeFeedPartitionReader.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.types.StructType import java.util -import scala.util.control.NonFatal private object ChangeFeedPartitionReader { val LsnPropertyName: String = LsnAttributeName @@ -187,19 +186,7 @@ private case class ChangeFeedPartitionReader private def getPartitionStartLsn: Option[Long] = { partition.continuationState.map { continuationState => - try { - val continuationTokens = SparkBridgeImplementationInternal - .extractContinuationTokensFromChangeFeedStateJson(continuationState) - - if (continuationTokens.nonEmpty) { - continuationTokens.minBy(_._2)._2 - } else { - SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationState) - } - } catch { - case NonFatal(_) => - SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationState) - } + SparkBridgeImplementationInternal.extractMinLatestLsnFromChangeFeedContinuationOrFallback(continuationState) } } @@ -306,8 +293,7 @@ private case class ChangeFeedPartitionReader // for cases where the feed range spans multiple physical partitions // pick the smallest lsn Some(SparkBridgeImplementationInternal - .extractContinuationTokensFromChangeFeedStateJson(continuationToken) - .minBy(_._2)._2) + .extractMinLatestLsnFromChangeFeedContinuationOrFallback(continuationToken)) case None => // for change feed, we would only reach here before the first page got fetched // fallback to use the continuation token from the partition instead diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReader.scala index 44027bafbe7e..02a49c1a1441 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReader.scala @@ -256,6 +256,7 @@ private case class ItemsPartitionReader readConfig.maxItemCount, readConfig.prefetchBufferSize, operationContextAndListenerTuple, + None, None ) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala index 14835b1e405b..0a9c4f492723 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala @@ -42,7 +42,7 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] val pagePrefetchBufferSize: Int, val operationContextAndListener: Option[OperationContextAndListenerTuple], val endLsn: Option[Long], - val startLsn: Option[Long] = None + val startLsn: Option[Long] ) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable { private[spark] var maxRetryIntervalInMs = CosmosConstants.maxRetryIntervalForTransientFailuresInMs @@ -74,11 +74,7 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] } def getLatestContinuationToken: Option[String] = { - if (lastContinuationToken == null) { - None - } else { - Some(lastContinuationToken.get()) - } + Option(lastContinuationToken.get()) } override def hasNext: Boolean = { @@ -241,8 +237,19 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] private[this] def validateEofProgressOrThrow(): Unit = { endLsn.foreach { targetEndLsn => - val latestMinLsn = Option(lastContinuationToken.get()) - .map(extractLatestMinLsnFromContinuation) + val latestMinLsn = try { + Option(lastContinuationToken.get()) + .map(extractLatestMinLsnFromContinuation) + } catch { + case NonFatal(parseFailure) => + val message = s"Continuation token parse failure - treating EOF as inconclusive. " + + s"startLsn: ${formatLsn(startLsn)}, endLsn: $targetEndLsn, " + + s"totalChangesCnt: ${totalChangesCnt.get()}, Context: $operationContextString" + val exception = new OperationCancelledException(message, null) + exception.addSuppressed(parseFailure) + logError(message, exception) + throw exception + } val isEofValid = latestMinLsn match { case Some(observedLsn) => observedLsn >= targetEndLsn @@ -252,34 +259,28 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] if (!isEofValid) { val observedLsnText = latestMinLsn.map(_.toString).getOrElse("no page consumed") val message = s"Bounded change feed read reached EOF before planned endLsn. " + - s"startLsn: $startLsn, endLsn: $targetEndLsn, observed minLatestLsn: $observedLsnText, " + - s"totalChangesCnt: ${totalChangesCnt.get()}, Context: $operationContextString" - - throw new OperationCancelledException(message, null) + s"startLsn: ${formatLsn(startLsn)}, endLsn: $targetEndLsn, " + + s"observed minLatestLsn: $observedLsnText, totalChangesCnt: ${totalChangesCnt.get()}, " + + s"Context: $operationContextString. If this occurred during Spark task cancellation/decommission, " + + s"expect the task to retry from the last committed checkpoint. Continuation tokens are expected " + + s"to preserve split child ranges; range-set shrinkage is undefined behavior." + val exception = new OperationCancelledException(message, null) + + // TODO: Consider moving bounded change-feed EOF validation into a dedicated decorator and short-circuiting + // deterministic zero-progress retries once this policy is separated from transient I/O retry handling. + logWarning(message) + logError(message, exception) + throw exception } } } private[this] def extractLatestMinLsnFromContinuation(continuationToken: String): Long = { - try { - val continuationTokens = SparkBridgeImplementationInternal - .extractContinuationTokensFromChangeFeedStateJson(continuationToken) + SparkBridgeImplementationInternal.extractMinLatestLsnFromChangeFeedContinuationOrFallback(continuationToken) + } - if (continuationTokens.nonEmpty) { - continuationTokens.map(_._2).min - } else { - SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationToken) - } - } catch { - case NonFatal(rangeEnumerationFailure) => - try { - SparkBridgeImplementationInternal.extractLsnFromChangeFeedContinuation(continuationToken) - } catch { - case NonFatal(fallbackFailure) => - rangeEnumerationFailure.addSuppressed(fallbackFailure) - throw rangeEnumerationFailure - } - } + private[this] def formatLsn(lsn: Option[Long]): String = { + lsn.map(_.toString).getOrElse("n/a") } // Clean up iterator references - the underlying Reactor subscription diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorITest.scala index 1eea43a57daa..a4f1d1e958a3 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorITest.scala @@ -113,6 +113,7 @@ class TransientIOErrorsRetryingIteratorITest 2, Queues.XS_BUFFER_SIZE, None, + None, None ) retryingIterator.maxRetryIntervalInMs = 5 @@ -255,6 +256,7 @@ class TransientIOErrorsRetryingIteratorITest 2, Queues.XS_BUFFER_SIZE, None, + None, None ) retryingIterator.maxRetryIntervalInMs = 5 diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala index b3167a0014af..dbad259a8c9d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIteratorSpec.scala @@ -47,6 +47,7 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr pageSize, 1, None, + None, None ) iterator.maxRetryIntervalInMs = 5 @@ -67,6 +68,7 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr pageSize, 1, None, + None, None ) iterator.maxRetryIntervalInMs = 5 @@ -87,6 +89,7 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr pageSize, 1, None, + None, None ) iterator.maxRetryIntervalInMs = 5 @@ -110,6 +113,7 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr pageSize, 1, None, + None, None ) iterator.maxRetryIntervalInMs = 5 @@ -135,6 +139,7 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr pageSize, 1, None, + None, None ) @@ -167,6 +172,7 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr pageSize, 1, None, + None, None ) @@ -201,7 +207,8 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr pageSize, 1, None, - Some(endLsn) + Some(endLsn), + Some(10L) ) iterator.maxRetryCount = 0 @@ -391,6 +398,101 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr iterator.hasNext shouldEqual false } + "Bounded change feed reads" should + "treat malformed continuation EOF as retryable inconclusive progress" in { + + val endLsn = 20L + val response = generateFeedResponse("ChangeFeed", 1, -1) + ModelBridgeInternal.setFeedResponseContinuationToken( + malformedChangeFeedContinuation(), + response + ) + + val iterator = new TransientIOErrorsRetryingIterator( + _ => UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(response)) + ), + pageSize, + 1, + None, + Some(endLsn), + Some(10L) + ) + iterator.maxRetryCount = 0 + + val error = intercept[OperationCancelledException](iterator.hasNext) + error.getStatusCode shouldEqual 408 + error.getMessage should include("Continuation token parse failure") + error.getSuppressed.length should be > 0 + } + + "Bounded change feed reads" should + "complete cleanly when multi-page progress reaches the planned end LSN" in { + + val endLsn = 20L + val rowAtLsn12 = ChangeFeedSparkRowItem(Row.empty, None, "12") + val firstResponse: FeedResponse[ChangeFeedSparkRowItem] = ModelBridgeInternal + .createFeedResponse( + java.util.Collections.singletonList(rowAtLsn12), + new ConcurrentHashMap[String, String] + ) + ModelBridgeInternal.setFeedResponseContinuationToken( + changeFeedContinuation(12L), + firstResponse + ) + + val secondResponse: FeedResponse[ChangeFeedSparkRowItem] = ModelBridgeInternal + .createFeedResponse( + java.util.Collections.emptyList[ChangeFeedSparkRowItem](), + new ConcurrentHashMap[String, String] + ) + ModelBridgeInternal.setFeedResponseContinuationToken( + changeFeedContinuation(endLsn), + secondResponse + ) + + val iterator = new TransientIOErrorsRetryingIterator[ChangeFeedSparkRowItem]( + _ => UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(firstResponse, secondResponse)) + ), + pageSize, + 1, + None, + Some(endLsn), + Some(10L) + ) + iterator.maxRetryCount = 0 + + iterator.hasNext shouldEqual true + iterator.next() shouldEqual rowAtLsn12 + iterator.hasNext shouldEqual false + } + + "Bounded change feed reads" should + "complete cleanly when startLsn equals endLsn and a page continuation reaches the same LSN" in { + + val endLsn = 20L + val response = generateFeedResponse("ChangeFeed", 1, -1) + ModelBridgeInternal.setFeedResponseContinuationToken( + changeFeedContinuation(endLsn), + response + ) + + val iterator = new TransientIOErrorsRetryingIterator( + _ => UtilBridgeInternal.createCosmosPagedFlux( + _ => Flux.fromArray(Array(response)) + ), + pageSize, + 1, + None, + Some(endLsn), + Some(endLsn) + ) + iterator.maxRetryCount = 0 + + iterator.hasNext shouldEqual false + } + private val objectMapper = new ObjectMapper @throws[JsonProcessingException] @@ -582,6 +684,10 @@ class TransientIOErrorsRetryingIteratorSpec extends UnitSpec with BasicLoggingTr Base64.getEncoder.encodeToString(state.getBytes("UTF-8")) } + private def malformedChangeFeedContinuation(): String = { + Base64.getEncoder.encodeToString("not-json".getBytes("UTF-8")) + } + private def multiRangeChangeFeedContinuation(lsns: Seq[Long]): String = { // Splits the [""..."FF"] range into evenly-sized adjacent sub-ranges, one per supplied LSN. // The boundaries are arbitrary hex strings; the iterator only inspects the per-range tokens. From 05a96a881e4cc05b55b3b34d33925a4cf62468cd Mon Sep 17 00:00:00 2001 From: Annie Liang Date: Fri, 5 Jun 2026 10:23:06 -0700 Subject: [PATCH 4/4] Address iter-2 review: drop duplicate log + inline helper wrapper - Removed redundant logWarning before logError on EOF under-run path (deep reviewer F11; previously logged the same message twice and could double-count on log-based alerting). - Inlined SparkBridgeImplementationInternal.extractMinLatestLsnFrom- ChangeFeedContinuationOrFallback at the call site and removed the one-line extractLatestMinLsnFromContinuation private wrapper (deep reviewer F12). - Skeptic Lens NonFatal/InterruptedException concern verified factually incorrect: scala.util.control.NonFatal.apply explicitly excludes InterruptedException, so the wrap cannot convert cancellation into a retry loop. - Caller-site audit verified: all 4 production + 2 test call sites pass startLsn explicitly. Shaded azure-cosmos-spark_4-1_2-13 uses generated target/shared-sources/ mirrors that sync automatically. 17/17 unit tests still pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cosmos/spark/TransientIOErrorsRetryingIterator.scala | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala index 0a9c4f492723..c99aebe02f80 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingIterator.scala @@ -239,7 +239,7 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] endLsn.foreach { targetEndLsn => val latestMinLsn = try { Option(lastContinuationToken.get()) - .map(extractLatestMinLsnFromContinuation) + .map(SparkBridgeImplementationInternal.extractMinLatestLsnFromChangeFeedContinuationOrFallback) } catch { case NonFatal(parseFailure) => val message = s"Continuation token parse failure - treating EOF as inconclusive. " + @@ -268,17 +268,12 @@ private class TransientIOErrorsRetryingIterator[TSparkRow] // TODO: Consider moving bounded change-feed EOF validation into a dedicated decorator and short-circuiting // deterministic zero-progress retries once this policy is separated from transient I/O retry handling. - logWarning(message) logError(message, exception) throw exception } } } - private[this] def extractLatestMinLsnFromContinuation(continuationToken: String): Long = { - SparkBridgeImplementationInternal.extractMinLatestLsnFromChangeFeedContinuationOrFallback(continuationToken) - } - private[this] def formatLsn(lsn: Option[Long]): String = { lsn.map(_.toString).getOrElse("n/a") }