
Commit 9b32334

Ngone51 authored and dongjoon-hyun committed
[SPARK-50768][SQL][CORE][FOLLOW-UP] Apply TaskContext.createResourceUninterruptibly() to risky resource creations
### What changes were proposed in this pull request? This is a follow-up PR for apache#49413. This PR intends to apply `TaskContext.createResourceUninterruptibly()` to the resource creation where it has the potential risk of resource leak in the case of task cancellation. ### Why are the changes needed? Avoid resource leak. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? n/a ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#49508 from Ngone51/SPARK-50768-followup. Authored-by: Yi Wu <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 90801c2 commit 9b32334
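
For context, the mechanism from apache#49413 that this follow-up relies on can be pictured roughly as follows. This is a simplified, hedged sketch, not Spark's actual `TaskContext`/`TaskContextImpl` code; the object name and `requestCancel` helper are illustrative, and synchronization is reduced to volatile flags for brevity.

```scala
import java.io.Closeable

// Illustrative model only: while a resource is being created, a kill request is
// recorded instead of delivered, and the interrupt is issued only after creation
// finishes, so the caller always ends up with a handle it can close.
object UninterruptibleCreationSketch {
  @volatile private var creating = false
  @volatile private var pendingInterrupt: Option[Thread] = None

  // Called on the task thread around a risky resource creation.
  def createResourceUninterruptibly[R <: Closeable](create: => R): R = {
    creating = true
    try {
      create // e.g. new HadoopFileLinesReader(...)
    } finally {
      creating = false
      // Deliver a cancellation that arrived while the resource was being built.
      pendingInterrupt.foreach(_.interrupt())
      pendingInterrupt = None
    }
  }

  // Called by the cancellation side (e.g. a task kill) instead of a raw interrupt.
  def requestCancel(taskThread: Thread): Unit = {
    if (creating) pendingInterrupt = Some(taskThread) else taskThread.interrupt()
  }
}
```

The call sites changed in this commit wrap their reader/stream creation in this kind of guard (via the new `Utils.createResourceUninterruptiblyIfInTaskThread` helper) while still registering a task-completion listener to close the resource.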

11 files changed: +73 -39 lines changed

core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
Lines changed: 2 additions & 2 deletions

@@ -278,12 +278,12 @@ class BarrierTaskContext private[spark] (
   override private[spark] def interruptible(): Boolean = taskContext.interruptible()

   override private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String)
-  : Unit = {
+    : Unit = {
     taskContext.pendingInterrupt(threadToInterrupt, reason)
   }

   override private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T)
-  : T = {
+    : T = {
     taskContext.createResourceUninterruptibly(resourceBuilder)
   }
 }

core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
Lines changed: 7 additions & 5 deletions

@@ -244,11 +244,13 @@ class NewHadoopRDD[K, V](
       private var finished = false
       private var reader =
         try {
-          Utils.tryInitializeResource(
-            format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
-          ) { reader =>
-            reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-            reader
+          Utils.createResourceUninterruptiblyIfInTaskThread {
+            Utils.tryInitializeResource(
+              format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
+            ) { reader =>
+              reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+              reader
+            }
           }
         } catch {
           case e: FileNotFoundException if ignoreMissingFiles =>

core/src/main/scala/org/apache/spark/util/Utils.scala
Lines changed: 12 additions & 0 deletions

@@ -3086,6 +3086,18 @@ private[spark] object Utils
     files.toSeq
   }

+  /**
+   * Create a resource uninterruptibly if we are in a task thread (i.e., TaskContext.get() != null).
+   * Otherwise, create the resource normally. This is mainly used in the situation where we want to
+   * create a multi-layer resource in a task thread. The uninterruptible behavior ensures we don't
+   * leak the underlying resources when there is a task cancellation request.
+   */
+  def createResourceUninterruptiblyIfInTaskThread[R <: Closeable](createResource: => R): R = {
+    Option(TaskContext.get()).map(_.createResourceUninterruptibly {
+      createResource
+    }).getOrElse(createResource)
+  }
+
   /**
    * Return the median number of a long array
    *
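
A minimal usage sketch of this new helper, mirroring the call sites changed below (it assumes code inside Spark, where `Utils` is accessible; the object and method names `ReaderOpenSketch`/`openLines` are stand-ins, not code from this commit):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.datasources.{HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.util.Utils

object ReaderOpenSketch {
  // Illustrative only: how a file-source reader combines the helper with a
  // task-completion listener so the reader is both created safely and closed.
  def openLines(file: PartitionedFile, conf: Configuration): HadoopFileLinesReader = {
    // Creation is shielded from task-kill interrupts, so a cancelled task
    // cannot leave a half-opened reader behind.
    val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
      new HadoopFileLinesReader(file, conf)
    )
    // The reader is still released when the task completes (or is cancelled).
    Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
    linesReader
  }
}
```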

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
Lines changed: 4 additions & 2 deletions

@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}

 private[libsvm] class LibSVMOutputWriter(
     val path: String,
@@ -156,7 +156,9 @@ private[libsvm] class LibSVMFileFormat
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

     (file: PartitionedFile) => {
-      val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+        new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      )
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))

       val points = linesReader

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
Lines changed: 4 additions & 1 deletion

@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils

 /**
  * Common functions for parsing CSV files
@@ -99,7 +100,9 @@ object TextInputCSVDataSource extends CSVDataSource {
       headerChecker: CSVHeaderChecker,
       requiredSchema: StructType): Iterator[InternalRow] = {
     val lines = {
-      val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+      val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+        new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+      )
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
       linesReader.map { line =>
         new String(line.getBytes, 0, line.getLength, parser.options.charset)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
Lines changed: 6 additions & 2 deletions

@@ -129,7 +129,9 @@ object TextInputJsonDataSource extends JsonDataSource {
       file: PartitionedFile,
       parser: JacksonParser,
       schema: StructType): Iterator[InternalRow] = {
-    val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+    val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+      new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+    )
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
     val textParser = parser.options.encoding
       .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text))
@@ -211,7 +213,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
       schema: StructType): Iterator[InternalRow] = {
     def partitionedFileString(ignored: Any): UTF8String = {
       Utils.tryWithResource {
-        CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
+        Utils.createResourceUninterruptiblyIfInTaskThread {
+          CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
+        }
       } { inputStream =>
         UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
       }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
Lines changed: 7 additions & 5 deletions

@@ -32,7 +32,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{DataType, StringType, StructType}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}

 /**
  * A data source for reading text files. The text files must be encoded as UTF-8.
@@ -119,10 +119,12 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {

     (file: PartitionedFile) => {
       val confValue = conf.value.value
-      val reader = if (!textOptions.wholeText) {
-        new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
-      } else {
-        new HadoopFileWholeTextReader(file, confValue)
+      val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
+        if (!textOptions.wholeText) {
+          new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
+        } else {
+          new HadoopFileWholeTextReader(file, confValue)
+        }
       }
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
       if (requiredSchema.isEmpty) {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
Lines changed: 16 additions & 14 deletions

@@ -261,21 +261,23 @@ case class ParquetPartitionReaderFactory(
       val int96RebaseSpec = DataSourceUtils.int96RebaseSpec(
         footerFileMetaData.getKeyValueMetaData.get,
         int96RebaseModeInRead)
-      Utils.tryInitializeResource(
-        buildReaderFunc(
-          file.partitionValues,
-          pushed,
-          convertTz,
-          datetimeRebaseSpec,
-          int96RebaseSpec)
-      ) { reader =>
-        reader match {
-          case vectorizedReader: VectorizedParquetRecordReader =>
-            vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
-          case _ =>
-            reader.initialize(split, hadoopAttemptContext)
+      Utils.createResourceUninterruptiblyIfInTaskThread {
+        Utils.tryInitializeResource(
+          buildReaderFunc(
+            file.partitionValues,
+            pushed,
+            convertTz,
+            datetimeRebaseSpec,
+            int96RebaseSpec)
+        ) { reader =>
+          reader match {
+            case vectorizedReader: VectorizedParquetRecordReader =>
+              vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
+            case _ =>
+              reader.initialize(split, hadoopAttemptContext)
+          }
+          reader
         }
-        reader
       }
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextPartitionReaderFactory.scala
Lines changed: 7 additions & 5 deletions

@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.text.TextOptions
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}

 /**
  * A factory used to create Text readers.
@@ -47,10 +47,12 @@ case class TextPartitionReaderFactory(

   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val confValue = broadcastedConf.value.value
-    val reader = if (!options.wholeText) {
-      new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
-    } else {
-      new HadoopFileWholeTextReader(file, confValue)
+    val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
+      if (!options.wholeText) {
+        new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
+      } else {
+        new HadoopFileWholeTextReader(file, confValue)
+      }
     }
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
     val iter = if (readDataSchema.isEmpty) {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
Lines changed: 4 additions & 1 deletion

@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils

 /**
  * Common functions for parsing XML files
@@ -97,7 +98,9 @@ object TextInputXmlDataSource extends XmlDataSource {
       parser: StaxXmlParser,
       schema: StructType): Iterator[InternalRow] = {
     val lines = {
-      val linesReader = new HadoopFileLinesReader(file, None, conf)
+      val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+        new HadoopFileLinesReader(file, None, conf)
+      )
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
       linesReader.map { line =>
         new String(line.getBytes, 0, line.getLength, parser.options.charset)
