Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.{SparkTaskUtil, TaskCompletionListener, TaskFailureListener}

import java.util
import java.util.{Collections, Properties, UUID}
import java.util.{Properties, UUID}
import java.util.concurrent.atomic.AtomicLong

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.compat.Platform.ConcurrentModificationException

object TaskResources extends TaskListener with Logging {
Expand Down Expand Up @@ -249,9 +248,9 @@ object TaskResources extends TaskListener with Logging {
// thread safe
class TaskResourceRegistry extends Logging {
private val sharedUsage = new SimpleMemoryUsageRecorder()
private val resources = new util.HashMap[String, TaskResource]()
private val priorityToResourcesMapping: util.HashMap[Int, util.LinkedHashSet[TaskResource]] =
new util.HashMap[Int, util.LinkedHashSet[TaskResource]]()
private val resources = mutable.Map.empty[String, TaskResource]
private val priorityToResourcesMapping: mutable.Map[Int, mutable.LinkedHashSet[TaskResource]] =
mutable.Map.empty[Int, mutable.LinkedHashSet[TaskResource]]

private var exclusiveLockAcquired: Boolean = false
private def lock[T](body: => T): T = {
Expand Down Expand Up @@ -279,7 +278,7 @@ class TaskResourceRegistry extends Logging {
private def addResource0(id: String, resource: TaskResource): Unit = lock {
resources.put(id, resource)
priorityToResourcesMapping
.computeIfAbsent(resource.priority(), _ => new util.LinkedHashSet[TaskResource]())
.getOrElseUpdate(resource.priority(), mutable.LinkedHashSet.empty[TaskResource])
.add(resource)
}

Expand All @@ -290,42 +289,24 @@ class TaskResourceRegistry extends Logging {

/** Release all managed resources according to priority and reversed order */
private[task] def releaseAll(): Unit = lock {
val table = new util.ArrayList(priorityToResourcesMapping.entrySet())
Collections.sort(
table,
(
o1: util.Map.Entry[Int, util.LinkedHashSet[TaskResource]],
o2: util.Map.Entry[Int, util.LinkedHashSet[TaskResource]]) => {
val diff = o2.getKey - o1.getKey // descending by priority
if (diff > 0) {
1
} else if (diff < 0) {
-1
} else {
throw new IllegalStateException(
"Unreachable code from org.apache.spark.task.TaskResourceRegistry.releaseAll")
}
}
)
table.forEach {
_.getValue.asScala.toSeq.reverse
.foreach(release(_)) // lifo for all resources within the same priority
priorityToResourcesMapping.toSeq.sortBy(-_._1).foreach {
case (_, resources) =>
resources.toSeq.reverse.foreach(release)
}
priorityToResourcesMapping.clear()
resources.clear()
}

/** Release single resource by ID */
private[task] def releaseResource(id: String): Unit = lock {
if (!resources.containsKey(id)) {
val resource = resources.getOrElse(
id,
throw new IllegalArgumentException(
String.format("TaskResource with ID %s is not registered", id))
}
val resource = resources.get(id)
if (!priorityToResourcesMapping.containsKey(resource.priority())) {
throw new IllegalStateException("TaskResource's priority not found in priority mapping")
}
val samePrio = priorityToResourcesMapping.get(resource.priority())
String.format("TaskResource with ID %s is not registered", id)))
val samePrio = priorityToResourcesMapping.getOrElse(
resource.priority(),
throw new IllegalStateException("TaskResource's priority not found in priority mapping"))

if (!samePrio.contains(resource)) {
throw new IllegalStateException("TaskResource not found in priority mapping")
}
Expand All @@ -336,16 +317,18 @@ class TaskResourceRegistry extends Logging {

private[task] def addResourceIfNotRegistered[T <: TaskResource](id: String, factory: () => T): T =
lock {
if (resources.containsKey(id)) {
return resources.get(id).asInstanceOf[T]
}
val resource = factory.apply()
addResource0(id, resource)
resource
resources
.getOrElse(
id, {
val resource = factory.apply()
addResource0(id, resource)
resource
})
.asInstanceOf[T]
}

private[task] def addResource[T <: TaskResource](id: String, resource: T): T = lock {
if (resources.containsKey(id)) {
if (resources.contains(id)) {
throw new IllegalArgumentException(
String.format("TaskResource with ID %s is already registered", id))
}
Expand All @@ -354,15 +337,16 @@ class TaskResourceRegistry extends Logging {
}

private[task] def isResourceRegistered(id: String): Boolean = lock {
resources.containsKey(id)
resources.contains(id)
}

private[task] def getResource[T <: TaskResource](id: String): T = lock {
if (!resources.containsKey(id)) {
throw new IllegalArgumentException(
String.format("TaskResource with ID %s is not registered", id))
}
resources.get(id).asInstanceOf[T]
resources
.getOrElse(
id,
throw new IllegalArgumentException(
String.format("TaskResource with ID %s is not registered", id)))
.asInstanceOf[T]
}

private[task] def getSharedUsage(): SimpleMemoryUsageRecorder = lock {
Expand Down