diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 7918d1618eb0..437046d2c0a6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -23,7 +23,7 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, FallbackStorage, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -88,7 +88,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
       SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
       readMetrics,
-      fetchContinuousBlocksInBatch).toCompletionIterator
+      fetchContinuousBlocksInBatch,
+      FallbackStorage.getFallbackStorage(SparkEnv.get.conf)).toCompletionIterator
 
     val serializerInstance = dep.serializer.newInstance()
 
diff --git a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala
index 19cdebd80ebf..ba10e63f5e24 100644
--- a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala
@@ -90,6 +90,11 @@ private[storage] class FallbackStorage(conf: SparkConf) extends Logging {
     }
   }
 
+  /**
+   * Read a ManagedBuffer.
+   */
+  def read(blockId: BlockId): ManagedBuffer = FallbackStorage.read(conf, blockId)
+
   def exists(shuffleId: Int, filename: String): Boolean = {
     val hash = JavaUtils.nonNegativeHash(filename)
     fallbackFileSystem.exists(new Path(fallbackPath, s"$appId/$shuffleId/$hash/$filename"))
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index b2f185bc590f..8ca35aa1f3a4 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -101,6 +101,7 @@
     checksumAlgorithm: String,
     shuffleMetrics: ShuffleReadMetricsReporter,
     doBatchFetch: Boolean,
+    fallbackStorage: Option[FallbackStorage],
     clock: Clock = new SystemClock())
   extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
 
@@ -973,14 +974,28 @@
           }
 
        case FailureFetchResult(blockId, mapIndex, address, e) =>
+          var error = e
          var errorMsg: String = null
          if (e.isInstanceOf[OutOfDirectMemoryError]) {
            val logMessage = log"Block ${MDC(BLOCK_ID, blockId)} fetch failed after " +
              log"${MDC(MAX_ATTEMPTS, maxAttemptsOnNettyOOM)} retries due to Netty OOM"
            logError(logMessage)
            errorMsg = logMessage.message
+          } else if (fallbackStorage.isDefined) {
+            try {
+              val buf = fallbackStorage.get.read(blockId)
+              results.put(SuccessFetchResult(blockId, mapIndex, address, buf.size(), buf,
+                isNetworkReqDone = false))
+              result = null
+              error = null
+            } catch {
+              case t: Throwable =>
+                logInfo(s"Failed to read block from fallback storage: $blockId", t)
+            }
+          }
+          if (error != null) {
+            throwFetchFailedException(blockId, mapIndex, address, error, Some(errorMsg))
          }
-          throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))
 
        case DeferFetchRequestResult(request) =>
          val address = request.address
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 211de2e8729e..3aba0fb8158e 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -196,7 +196,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       checksumEnabled: Boolean = true,
       checksumAlgorithm: String = "ADLER32",
       shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
-      doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+      doBatchFetch: Boolean = false,
+      fallbackStorage: Option[FallbackStorage] = None): ShuffleBlockFetcherIterator = {
     val tContext = taskContext.getOrElse(TaskContext.empty())
     new ShuffleBlockFetcherIterator(
       tContext,
@@ -222,7 +223,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       checksumEnabled,
       checksumAlgorithm,
       shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()),
-      doBatchFetch)
+      doBatchFetch,
+      fallbackStorage)
   }
   // scalastyle:on argcount
 
@@ -1127,6 +1129,54 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     assert(e.getMessage.contains("fetch failed after 10 retries due to Netty OOM"))
   }
 
+  test("SPARK-52507: missing blocks attempts to read from fallback storage") {
+    val blockManager = createMockBlockManager()
+
+    configureMockTransfer(Map.empty)
+    val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2)
+    val blockId = ShuffleBlockId(0, 0, 0)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (remoteBmId, Seq((blockId, 1L, 0)))
+    )
+
+    // iterator with no FallbackStorage cannot find the block
+    {
+      val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress = blocksByAddress)
+      val e = intercept[FetchFailedException] {
+        iterator.next()
+      }
+      assert(e.getCause != null)
+      assert(e.getCause.isInstanceOf[BlockNotFoundException])
+      assert(e.getCause.getMessage.contains("Block shuffle_0_0_0 not found"))
+    }
+
+    // iterator with FallbackStorage that does not store the block cannot find it either
+    val fallbackStorage = mock(classOf[FallbackStorage])
+
+    {
+      when(fallbackStorage.read(ShuffleBlockId(0, 0, 1))).thenReturn(new TestManagedBuffer(127))
+      val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress = blocksByAddress,
+        fallbackStorage = Some(fallbackStorage))
+      val e = intercept[FetchFailedException] {
+        iterator.next()
+      }
+      assert(e.getCause != null)
+      assert(e.getCause.isInstanceOf[BlockNotFoundException])
+      assert(e.getCause.getMessage.contains("Block shuffle_0_0_0 not found"))
+    }
+
+    // iterator with FallbackStorage that stores the block can find it
+    {
+      when(fallbackStorage.read(ShuffleBlockId(0, 0, 0))).thenReturn(new TestManagedBuffer(127))
+      val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress = blocksByAddress,
+        fallbackStorage = Some(fallbackStorage))
+      assert(iterator.hasNext)
+      val (id, _) = iterator.next()
+      assert(id === ShuffleBlockId(0, 0, 0))
+      assert(!iterator.hasNext)
+    }
+  }
+
   /**
    * Prepares the transfer to trigger success for all the blocks present in blockChunks. It will
    * trigger failure of block which is not part of blockChunks.