Skip to content

Commit d4ed074

Browse files
committed
Fix blocking of foreach() if the waiting thread is interrupted.
1 parent 5d65ccb commit d4ed074

File tree

3 files changed

+139
-16
lines changed

3 files changed

+139
-16
lines changed

src/main/scala/com/github/yruslan/channel/Channel.scala

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,24 @@ abstract class Channel[T] extends ReadChannel[T] with WriteChannel[T] {
7777
@throws[InterruptedException]
7878
final override def foreach[U](f: T => U): Unit = {
7979
while (true) {
80+
var valOpt: Option[T] = None
81+
8082
lock.lock()
81-
readers += 1
82-
while (!closed && !hasMessages) {
83-
awaitReaders()
84-
}
85-
readers -= 1
86-
if (isClosed) {
83+
try {
84+
readers += 1
85+
while (!closed && !hasMessages) {
86+
awaitReaders()
87+
}
88+
readers -= 1
89+
if (isClosed) {
90+
return
91+
}
92+
93+
valOpt = fetchValueOpt()
94+
} finally {
8795
lock.unlock()
88-
return
8996
}
9097

91-
val valOpt = fetchValueOpt()
92-
lock.unlock()
93-
9498
valOpt.foreach(f)
9599
}
96100
}
@@ -187,6 +191,7 @@ abstract class Channel[T] extends ReadChannel[T] with WriteChannel[T] {
187191

188192
protected def hasMessages: Boolean
189193

194+
/* This method assumes the lock is being held. */
190195
@throws[InterruptedException]
191196
final protected def awaitWriters(): Unit = {
192197
try {
@@ -199,6 +204,7 @@ abstract class Channel[T] extends ReadChannel[T] with WriteChannel[T] {
199204
}
200205
}
201206

207+
/* This method assumes the lock is being held. */
202208
@throws[InterruptedException]
203209
final protected def awaitWriters(awaiter: Awaiter): Boolean = {
204210
try {
@@ -211,6 +217,7 @@ abstract class Channel[T] extends ReadChannel[T] with WriteChannel[T] {
211217
}
212218
}
213219

220+
/* This method assumes the lock is being held. */
214221
@throws[InterruptedException]
215222
final protected def awaitReaders(): Unit = {
216223
try {
@@ -223,6 +230,7 @@ abstract class Channel[T] extends ReadChannel[T] with WriteChannel[T] {
223230
}
224231
}
225232

233+
/* This method assumes the lock is being held. */
226234
@throws[InterruptedException]
227235
final protected def awaitReaders(awaiter: Awaiter): Boolean = {
228236
try {

src/test/scala/com/github/yruslan/channel/ChannelSuite.scala

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515

1616
package com.github.yruslan.channel
1717

18-
import java.time.Instant
19-
import java.util.concurrent.{Executors, TimeUnit}
20-
2118
import com.github.yruslan.channel.Channel.select
19+
import com.github.yruslan.channel.mocks.AsyncChannelSpy
2220
import org.scalatest.BeforeAndAfterAll
2321
import org.scalatest.wordspec.AnyWordSpec
2422

23+
import java.time.Instant
24+
import java.util.concurrent.{Executors, TimeUnit}
2525
import scala.collection.mutable.ListBuffer
2626
import scala.concurrent._
2727
import scala.concurrent.duration.{Duration, SECONDS}
@@ -32,7 +32,7 @@ import com.github.yruslan.channel.Channel
3232
class ChannelSuite extends AnyWordSpec with BeforeAndAfterAll {
3333
implicit private var ec: ExecutionContextExecutor = _
3434

35-
private val ex = Executors.newFixedThreadPool(12)
35+
private val ex = Executors.newFixedThreadPool(16)
3636

3737
override def beforeAll(): Unit = {
3838
super.beforeAll()
@@ -111,7 +111,7 @@ class ChannelSuite extends AnyWordSpec with BeforeAndAfterAll {
111111

112112
"closing a synchronous channel should block until the pending message is not received" in {
113113
val start = Instant.now()
114-
val ch = Channel.make[Int]
114+
val ch = Channel.make[Int](0)
115115
var v: Option[Int] = None
116116

117117
Future {
@@ -218,6 +218,90 @@ class ChannelSuite extends AnyWordSpec with BeforeAndAfterAll {
218218
assert(v1 == 0)
219219
assert(v2 == 9999)
220220
}
221+
222+
"send/foreach should handle interrupted thread" in {
223+
val ch = new AsyncChannelSpy[Int](1)
224+
225+
val output = new ListBuffer[Int]
226+
227+
val t1 = createThread {
228+
ch.foreach(a =>
229+
ch.synchronized {
230+
output += a
231+
}
232+
)
233+
}
234+
235+
val t2 = createThread {
236+
ch.foreach(a =>
237+
ch.synchronized {
238+
output += a
239+
}
240+
)
241+
}
242+
243+
t1.start()
244+
t2.start()
245+
246+
ch.send(100)
247+
ch.send(200)
248+
ch.send(300)
249+
250+
t1.interrupt()
251+
252+
ch.send(400)
253+
ch.send(500)
254+
ch.send(600)
255+
ch.send(700)
256+
ch.close()
257+
258+
t2.join(2000)
259+
t1.join(2000)
260+
261+
assert(output.sorted.toList == List(100, 200, 300, 400, 500, 600, 700))
262+
assert(ch.numOfReaders == 0)
263+
assert(ch.numOfWriters == 0)
264+
}
265+
266+
"send/recv should handle interrupted thread" in {
267+
val ch = new AsyncChannelSpy[Int](1)
268+
269+
val output = new ListBuffer[Int]
270+
271+
val t1 = createThread {
272+
ch.send(100)
273+
ch.send(200)
274+
ch.send(300)
275+
}
276+
277+
val t2 = createThread {
278+
ch.send(400)
279+
Thread.sleep(30)
280+
ch.send(500)
281+
ch.send(600)
282+
ch.send(700)
283+
ch.close()
284+
}
285+
286+
t1.start()
287+
t2.start()
288+
289+
output += ch.recv()
290+
291+
t1.interrupt()
292+
293+
ch.foreach(v => output += v)
294+
295+
t2.join(2000)
296+
t1.join(2000)
297+
298+
assert(output.contains(400))
299+
assert(output.contains(500))
300+
assert(output.contains(600))
301+
assert(output.contains(700))
302+
assert(ch.numOfReaders == 0)
303+
assert(ch.numOfWriters == 0)
304+
}
221305
}
222306

223307
"trySend() for sync channels" should {
@@ -709,7 +793,7 @@ class ChannelSuite extends AnyWordSpec with BeforeAndAfterAll {
709793
actions.append(s"S$workerNum$i-")
710794
}
711795
)
712-
if (!k) throw new IllegalArgumentException("sss")
796+
if (!k) throw new IllegalArgumentException("Failing the worker")
713797
}
714798
}
715799

@@ -1080,4 +1164,11 @@ class ChannelSuite extends AnyWordSpec with BeforeAndAfterAll {
10801164
}
10811165
}
10821166

1167+
private def createThread(action: => Unit) = {
1168+
// Creating thread in the Scala 2.11 compatible way.
1169+
// Please do not remove 'new Runnable'
1170+
new Thread(new Runnable {
1171+
def run(): Unit = action
1172+
})
1173+
}
10831174
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) 2020 Ruslan Yushchenko
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
package com.github.yruslan.channel.mocks
17+
18+
import com.github.yruslan.channel.AsyncChannel
19+
20+
class AsyncChannelSpy[T](maxCapacity: Int) extends AsyncChannel[T](maxCapacity) {
21+
def numOfReaders: Int = readers
22+
23+
def numOfWriters: Int = writers
24+
}

0 commit comments

Comments
 (0)