 */
package org.apache.spark.storage

-import java.io.{DataOutputStream, File, FileOutputStream, InputStream, IOException}
+import java.io.{DataOutputStream, File, FileNotFoundException, FileOutputStream, InputStream, IOException}
import java.nio.file.Files

import scala.concurrent.duration._
import scala.util.Random

import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FSDataInputStream, LocalFileSystem, Path, PositionedReadable, Seekable}
+import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, LocalFileSystem, Path, PositionedReadable, Seekable}
import org.mockito.{ArgumentMatchers => mc}
import org.mockito.Mockito.{mock, never, verify, when}
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TestUtils}
import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.launcher.SparkLauncher.{EXECUTOR_MEMORY, SPARK_MASTER}
@@ -39,6 +40,7 @@ import org.apache.spark.scheduler.ExecutorDecommissionInfo
import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.util.Clock
import org.apache.spark.util.Utils.tryWithResource

class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
@@ -334,7 +336,44 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
      }
    }
  }
+
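+  // The file appears in fallback storage only after `replicationSeconds`. FallbackStorage.open
+  // is expected to retry every `wait` seconds and to give up after at most `delay` seconds.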
+  Seq(0, 1, 3, 6).foreach { replicationSeconds =>
+    test(s"Consider replication delay - ${replicationSeconds}s") {
+      val replicationMs = replicationSeconds * 1000
+      val delay = 5 // max allowed replication delay (in seconds)
+      val wait = 2 // time between open file attempts (in seconds)
+      val conf = getSparkConf()
+        .set(STORAGE_DECOMMISSION_FALLBACK_STORAGE_REPLICATION_DELAY.key, s"${delay}s")
+        .set(STORAGE_DECOMMISSION_FALLBACK_STORAGE_REPLICATION_WAIT.key, s"${wait}s")
+
+      val filesystem = FileSystem.get(SparkHadoopUtil.get.newConfiguration(conf))
+      val path = new Path(conf.get(STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).get, "file")
+      val startMs = 123000000L * 1000L // arbitrary system time
+      val clock = new DelayedActionClock(replicationMs, startMs)(filesystem.create(path).close())
+
+      if (replicationSeconds <= delay) {
+        // expect open to succeed
+        val in = FallbackStorage.open(conf, filesystem, path, clock)
+        assert(in != null)
+
+        // how many waits are expected to observe replication
+        val expectedWaits = Math.ceil(replicationSeconds.toFloat / wait).toInt
+        assert(clock.timeMs == startMs + expectedWaits * wait * 1000)
+        assert(clock.waited == expectedWaits)
+        in.close()
+      } else {
+        // expect open to fail
+        assertThrows[FileNotFoundException](FallbackStorage.open(conf, filesystem, path, clock))
+
+        // how many waits are expected to observe delay
+        val expectedWaits = delay / wait
+        assert(clock.timeMs == startMs + expectedWaits * wait * 1000)
+        assert(clock.waited == expectedWaits)
+      }
+    }
+  }
}
+
class ReadPartialInputStream(val in: FSDataInputStream) extends InputStream
  with Seekable with PositionedReadable {
  override def read: Int = in.read
@@ -378,3 +417,30 @@ class ReadPartialFileSystem extends LocalFileSystem {
    new FSDataInputStream(new ReadPartialInputStream(stream))
  }
}
+
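+// Manual Clock for tests: simulated time advances only through waitTillTime, and `action` runs
+// once the simulated time reaches startTimeMs + delayMs (immediately when delayMs is 0).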
+class DelayedActionClock(delayMs: Long, startTimeMs: Long)(action: => Unit) extends Clock {
+  var timeMs: Long = startTimeMs
+  var waited: Int = 0
+  var triggered: Boolean = false
+
+  if (delayMs == 0) trigger()
+
+  private def trigger(): Unit = {
+    if (!triggered) {
+      triggered = true
+      action
+    }
+  }
+
+  override def getTimeMillis(): Long = timeMs
+  override def nanoTime(): Long = timeMs * 1000000
+  override def waitTillTime(targetTime: Long): Long = {
+    waited += 1
+    if (targetTime >= startTimeMs + delayMs) {
+      timeMs = startTimeMs + delayMs
+      trigger()
+    }
+    timeMs = targetTime
+    targetTime
+  }
+}