Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,30 @@ public void demand(Runnable onContent)
super.demand(new DemandTask(Invocable.getInvocationType(onContent)));
}

private void onContent()
private void onContent(InvocationType invocationType)
{
Permit permit = _remote.acquire();
if (permit.isAllocated())
onPermittedContent(permit);
else
permit.whenAllocated(this::onPermittedContent);
switch (invocationType)
{
case NON_BLOCKING ->
{
Runnable onContent = _onContent.getAndSet(null);
onContent.run();
}
case EITHER ->
{
Runnable onContent = _onContent.getAndSet(null);
Invocable.invokeNonBlocking(onContent);
}
case BLOCKING ->
{
Permit permit = _remote.acquire();
if (permit.isAllocated())
onPermittedContent(permit);
else
permit.whenAllocated(this::onPermittedContent);
}
default -> throw new IllegalStateException(invocationType.name());
}
}

private void onPermittedContent(Permit permit)
Expand All @@ -378,7 +395,7 @@ private DemandTask(InvocationType invocationType)
@Override
public void run()
{
onContent();
onContent(getInvocationType());
}
}
}
Expand All @@ -405,6 +422,30 @@ public void write(boolean last, ByteBuffer byteBuffer, Callback callback)
@Override
public void succeeded()
{
Callback callback = _writeCallback.get();
switch (callback.getInvocationType())
{
case NON_BLOCKING ->
{
_writeCallback.set(null);
callback.succeeded();
}
case EITHER ->
{
_writeCallback.set(null);
Invocable.invokeNonBlocking(callback::succeeded);
}
case BLOCKING ->
{
Permit permit = _remote.acquire();
if (permit.isAllocated())
permittedSuccess(permit);
else
permit.whenAllocated(this::permittedSuccess);
}
default -> throw new IllegalStateException(callback.getInvocationType().name());
}

Permit permit = _remote.acquire();
if (permit.isAllocated())
permittedSuccess(permit);
Expand All @@ -427,11 +468,29 @@ private void permittedSuccess(Permit permit)
@Override
public void failed(Throwable x)
{
Permit permit = _remote.acquire();
if (permit.isAllocated())
permittedFailure(permit, x);
else
permit.whenAllocated(p -> permittedFailure(p, x));
Callback callback = _writeCallback.get();
switch (callback.getInvocationType())
{
case NON_BLOCKING ->
{
_writeCallback.set(null);
callback.failed(x);
}
case EITHER ->
{
_writeCallback.set(null);
Invocable.invokeNonBlocking(() -> callback.failed(x));
}
case BLOCKING ->
{
Permit permit = _remote.acquire();
if (permit.isAllocated())
permittedFailure(permit, x);
else
permit.whenAllocated(p -> permittedFailure(p, x));
}
default -> throw new IllegalStateException(callback.getInvocationType().name());
}
}

private void permittedFailure(Permit permit, Throwable x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package org.eclipse.jetty.server.handler;

import java.net.Socket;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -23,6 +25,7 @@
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.io.EofException;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.LocalConnector;
Expand All @@ -31,8 +34,12 @@
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.Blocker;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.thread.Invocable;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -41,6 +48,7 @@
import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -378,4 +386,133 @@ public void run()

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
public void testBlockingWrite() throws Exception
{
ThreadLimitHandler handler = new ThreadLimitHandler("Forwarded");
handler.setThreadLimit(1);

CompletableFuture<Throwable> future = new CompletableFuture<>();
AtomicReference<Thread> threadReference = new AtomicReference<>();
handler.setHandler(new Handler.Abstract()
{
@Override
public boolean handle(Request request, Response response, Callback callback)
{
try
{
threadReference.set(Thread.currentThread());
while (true)
{
try (Blocker.Callback blocking = Blocker.callback())
{
response.write(false, BufferUtil.toBuffer("x".repeat(1024)), blocking);
blocking.block();
}
}
}
catch (Exception e)
{
future.complete(e);
callback.failed(e);
}
return true;
}
});
_server.setHandler(handler);
_server.start();

Socket client = new Socket("127.0.0.1", _connector.getLocalPort());
client.getOutputStream().write(("POST /" + " HTTP/1.0\r\nForwarded: for=1.2.3.4\r\nContent-Length: 0\r\n\r\n").getBytes());

// Validate the first 100 bytes of the response.
byte[] bytes = client.getInputStream().readNBytes(100);
String utf8String = StringUtil.toUTF8String(bytes, 0, bytes.length);
assertThat(utf8String, Matchers.containsString(" 200 OK"));
assertThat(utf8String, Matchers.containsString("xxxxx"));

// Delay to let the write side get TCP blocked, then close the connection.
// This ensures the server write callback will be failed by a different thread.
await().atMost(5, TimeUnit.SECONDS).pollDelay(1, TimeUnit.SECONDS).until(() ->
{
Thread thread = threadReference.get();
return thread != null && thread.getState() == Thread.State.WAITING;
});
client.close();

// The blocker callback will be failed by a different thread than the handling thread.
// This will test that the ThreadLimitHandler allows the NON_BLOCKING invocation type of Blocker.Callback
// to temporarily exceed the thread limit allowing it to unblock the handling thread.
Throwable throwable = future.get(5, TimeUnit.SECONDS);
assertThat(throwable, instanceOf(EofException.class));
}

@Test
public void testBlockingRead() throws Exception
{
ThreadLimitHandler handler = new ThreadLimitHandler("Forwarded");
handler.setThreadLimit(1);

AtomicInteger count = new AtomicInteger();
CountDownLatch awaitingMoreContent = new CountDownLatch(1);
handler.setHandler(new Handler.Abstract()
{
@Override
public boolean handle(Request request, Response response, Callback callback)
{
try
{
while (true)
{
Content.Chunk chunk = request.read();
if (chunk == null)
{
// Block waiting for the next content.
CountDownLatch latch = new CountDownLatch(1);
request.demand(Invocable.from(InvocationType.NON_BLOCKING, latch::countDown));
awaitingMoreContent.countDown();
latch.await();
continue;
}

if (Content.Chunk.isFailure(chunk))
throw chunk.getFailure();

count.addAndGet(chunk.remaining());
chunk.release();

if (chunk.isLast())
break;
}
}
catch (Throwable t)
{
callback.failed(t);
}

callback.succeeded();
return true;
}
});
_server.setHandler(handler);
_server.start();

Socket client = new Socket("127.0.0.1", _connector.getLocalPort());
client.getOutputStream().write(("POST /" + " HTTP/1.0\r\nForwarded: for=1.2.3.4\r\nContent-Length: 128\r\n\r\n").getBytes());

// Wait until the server has read a null chunk and has demanded more content.
assertTrue(awaitingMoreContent.await(5, TimeUnit.SECONDS));

// Write some content to the server to unblock it.
client.getOutputStream().write("x".repeat(128).getBytes());

// Assert that the server read all the content.
await().atMost(Duration.ofSeconds(5)).untilAtomic(count, is(128));

// Assert we got a 200 response.
String response = IO.toString(client.getInputStream());
assertThat(response, Matchers.containsString("200 OK"));
client.close();
}
}