diff --git a/core/src/main/java/io/undertow/UndertowOptions.java b/core/src/main/java/io/undertow/UndertowOptions.java index b8d0b2553c..2a28a67a3b 100644 --- a/core/src/main/java/io/undertow/UndertowOptions.java +++ b/core/src/main/java/io/undertow/UndertowOptions.java @@ -342,10 +342,15 @@ public class UndertowOptions { */ public static final Option MAX_CACHED_HEADER_SIZE = Option.simple(UndertowOptions.class, "MAX_CACHED_HEADER_SIZE", Integer.class); + /** + * Default value of {@link #HTTP_HEADERS_CACHE_SIZE} option. + */ public static final int DEFAULT_HTTP_HEADERS_CACHE_SIZE = 15; /** - * The maximum number of headers that are cached per connection. Defaults to 15. If this is set to zero the cache is disabled. + * The maximum number of headers that are cached per connection. If this is set to zero the cache is disabled. + *

+ * Defaults to {@link #DEFAULT_HTTP_HEADERS_CACHE_SIZE} */ public static final Option HTTP_HEADERS_CACHE_SIZE = Option.simple(UndertowOptions.class, "HTTP_HEADERS_CACHE_SIZE", Integer.class); diff --git a/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java b/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java index ec8a8e3a61..2aa6cb816f 100644 --- a/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java +++ b/core/src/main/java/io/undertow/conduits/ChunkedStreamSinkConduit.java @@ -35,6 +35,7 @@ import io.undertow.util.HeaderValues; import io.undertow.util.Headers; import io.undertow.util.ImmediatePooledByteBuffer; +import org.xnio.Buffers; import org.xnio.IoUtils; import io.undertow.connector.ByteBufferPool; import io.undertow.connector.PooledByteBuffer; @@ -128,6 +129,109 @@ public int write(final ByteBuffer src) throws IOException { return doWrite(src); } + long doWrite(final ByteBuffer[] srcs, int offset, int length) throws IOException { + if (anyAreSet(state, FLAG_WRITES_SHUTDOWN)) { + throw new ClosedChannelException(); + } + // Write as many buffers as possible without a chunk-size overflowing an integer. + long totalRemaining = 0; + for (int i = 0; i < length; i++) { + ByteBuffer buf = srcs[i + offset]; + int remaining = buf.remaining(); + if (totalRemaining + remaining > Integer.MAX_VALUE) { + // Avoid producing chunks too large for clients by reducing the number of buffers + // until total remaining fits within a 32-bit signed integer value. This is safe + // because a single java ByteBuffer has a capacity represented by an integer. + length = i; + break; + } + totalRemaining += remaining; + } + if(totalRemaining == 0) { + return 0; + } + int remaining = (int) totalRemaining; + this.state |= FLAG_FIRST_DATA_WRITTEN; + int oldLimit = srcs[length - 1].limit(); + boolean dataRemaining = false; //set to true if there is data in src that still needs to be written out + if (chunkleft == 0 && !chunkingSepBuffer.hasRemaining()) { + chunkingBuffer.clear(); + putIntAsHexString(chunkingBuffer, remaining); + chunkingBuffer.put(CRLF); + chunkingBuffer.flip(); + chunkingSepBuffer.clear(); + chunkingSepBuffer.put(CRLF); + chunkingSepBuffer.flip(); + state |= FLAG_WRITTEN_FIRST_CHUNK; + chunkleft = remaining; + } else { + int maxRemaining = chunkleft; + for (int i = 0; i < length; i++) { + ByteBuffer buf = srcs[offset + i]; + int bufRemaining = buf.remaining(); + if (bufRemaining >= maxRemaining) { + length = i + 1; + oldLimit = buf.limit(); + dataRemaining = true; + buf.limit(buf.position() + maxRemaining); + break; + } + maxRemaining -= bufRemaining; + } + } + try { + int chunkingSize = chunkingBuffer.remaining(); + int chunkingSepSize = chunkingSepBuffer.remaining(); + if (chunkingSize > 0 || chunkingSepSize > 0 || lastChunkBuffer != null) { + int originalRemaining = (int) Buffers.remaining(srcs, offset, length); + long result; + if (lastChunkBuffer == null || dataRemaining) { + // chunkingBuffer + // srcs (taking into account offset+length) + // chunkingSepBuffer + final ByteBuffer[] buf = new ByteBuffer[2 + length]; + buf[0] = chunkingBuffer; + System.arraycopy(srcs, offset , buf, 1, length); + buf[length + 1] = chunkingSepBuffer; + result = next.write(buf, 0, buf.length); + } else { + // chunkingBuffer + // srcs (taking into account offset+length) + // lastChunkBuffer + final ByteBuffer[] buf = new ByteBuffer[2 + length]; + buf[0] = chunkingBuffer; + System.arraycopy(srcs, offset , buf, 1, length); + buf[length + 1] = lastChunkBuffer.getBuffer(); + if (anyAreSet(state, CONF_FLAG_PASS_CLOSE)) { + result = next.writeFinal(buf, 0, buf.length); + } else { + result = next.write(buf, 0, buf.length); + } + if (Buffers.remaining(srcs, offset, length) == 0) { + state |= FLAG_WRITES_SHUTDOWN; + } + if (!lastChunkBuffer.getBuffer().hasRemaining()) { + state |= FLAG_NEXT_SHUTDOWN; + lastChunkBuffer.close(); + } + } + int srcWritten = originalRemaining - (int) Buffers.remaining(srcs, offset, length); + chunkleft -= srcWritten; + if (result < chunkingSize) { + return 0; + } else { + return srcWritten; + } + } else { + long result = next.write(srcs, offset, length); + chunkleft -= result; + return result; + + } + } finally { + srcs[length - 1].limit(oldLimit); + } + } int doWrite(final ByteBuffer src) throws IOException { if (anyAreSet(state, FLAG_WRITES_SHUTDOWN)) { @@ -195,7 +299,6 @@ int doWrite(final ByteBuffer src) throws IOException { } finally { src.limit(oldLimit); } - } @Override @@ -217,13 +320,7 @@ public void truncateWrites() throws IOException { @Override public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException { - for (int i = 0; i < length; i++) { - ByteBuffer srcBuffer = srcs[offset + i]; - if (srcBuffer.hasRemaining()) { - return write(srcBuffer); - } - } - return 0; + return doWrite(srcs, offset, length); } @Override @@ -382,7 +479,7 @@ private void createLastChunk(final boolean writeFinal) throws UnsupportedEncodin lastChunkBuffer.put(CRLF); } //horrible hack - //there is a situation where we can get a buffer leak here if the connection is terminated abnormaly + //there is a situation where we can get a buffer leak here if the connection is terminated abnormally //this should be fixed once this channel has its lifecycle tied to the connection, same as fixed length lastChunkBuffer.flip(); ByteBuffer data = ByteBuffer.allocate(lastChunkBuffer.remaining()); diff --git a/core/src/main/java/io/undertow/server/handlers/resource/URLResource.java b/core/src/main/java/io/undertow/server/handlers/resource/URLResource.java index 089a5eb4c8..b906141040 100644 --- a/core/src/main/java/io/undertow/server/handlers/resource/URLResource.java +++ b/core/src/main/java/io/undertow/server/handlers/resource/URLResource.java @@ -27,6 +27,7 @@ import java.net.URLConnection; import java.nio.ByteBuffer; import java.nio.file.DirectoryStream; +import java.nio.file.FileSystemNotFoundException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -299,8 +300,15 @@ public Path getFilePath() { } catch (URISyntaxException e) { return null; } + } else { + //deffer to Paths/FS --> ServiceLoader for java.nio.file.spi.FileSystemProvider + //NOTE: FS has to be installed: java.nio.file.FileSystems#newFileSystem + try { + return Paths.get(url.toURI()); + } catch(FileSystemNotFoundException|IllegalArgumentException|URISyntaxException e) { + return null; + } } - return null; } @Override diff --git a/core/src/main/java/io/undertow/server/protocol/framed/AbstractFramedStreamSourceChannel.java b/core/src/main/java/io/undertow/server/protocol/framed/AbstractFramedStreamSourceChannel.java index 6ed14535d1..c3d572ea32 100644 --- a/core/src/main/java/io/undertow/server/protocol/framed/AbstractFramedStreamSourceChannel.java +++ b/core/src/main/java/io/undertow/server/protocol/framed/AbstractFramedStreamSourceChannel.java @@ -122,21 +122,31 @@ public long transferTo(long position, long count, FileChannel target) throws IOE return 0; } try { + final PooledByteBuffer localData = data; if (frameDataRemaining == 0 && anyAreSet(state, STATE_LAST_FRAME)) { synchronized (lock) { state |= STATE_RETURNED_MINUS_ONE; return -1; } - } else if (data != null) { - int old = data.getBuffer().limit(); + } else if (localData != null) { try { - if (count < data.getBuffer().remaining()) { - data.getBuffer().limit((int) (data.getBuffer().position() + count)); + final int old = localData.getBuffer().limit(); + try { + if (count < localData.getBuffer().remaining()) { + localData.getBuffer().limit((int) (localData.getBuffer().position() + count)); + } + return target.write(localData.getBuffer(), position); + } finally { + localData.getBuffer().limit(old); + decrementFrameDataRemaining(); + } + } catch (IllegalStateException e) { + // NPE should be covered. ISE in case of closed buffer + if (anyAreSet(state, STATE_DONE | STATE_CLOSED | STATE_STREAM_BROKEN)) { + return -1; + } else { + throw e; } - return target.write(data.getBuffer(), position); - } finally { - data.getBuffer().limit(old); - decrementFrameDataRemaining(); } } return 0; @@ -146,7 +156,8 @@ public long transferTo(long position, long count, FileChannel target) throws IOE } private void decrementFrameDataRemaining() { - if(!data.getBuffer().hasRemaining()) { + final PooledByteBuffer localData = data; + if(localData != null && !localData.getBuffer().hasRemaining()) { frameDataRemaining -= currentDataOriginalSize; } } @@ -162,36 +173,44 @@ public long transferTo(long count, ByteBuffer throughBuffer, StreamSinkChannel s return 0; } try { + final PooledByteBuffer localData = data; if (frameDataRemaining == 0 && anyAreSet(state, STATE_LAST_FRAME)) { synchronized (lock) { state |= STATE_RETURNED_MINUS_ONE; return -1; } - } else if (data != null && data.getBuffer().hasRemaining()) { - int old = data.getBuffer().limit(); + } else if (localData != null && localData.getBuffer().hasRemaining()) { + int old = localData.getBuffer().limit(); try { - if (count < data.getBuffer().remaining()) { - data.getBuffer().limit((int) (data.getBuffer().position() + count)); + if (count < localData.getBuffer().remaining()) { + localData.getBuffer().limit((int) (localData.getBuffer().position() + count)); } - int written = streamSinkChannel.write(data.getBuffer()); - if(data.getBuffer().hasRemaining()) { + int written = streamSinkChannel.write(localData.getBuffer()); + if(localData.getBuffer().hasRemaining()) { //we can still add more data //stick it it throughbuffer, otherwise transfer code will continue to attempt to use this method throughBuffer.clear(); - Buffers.copy(throughBuffer, data.getBuffer()); + Buffers.copy(throughBuffer, localData.getBuffer()); throughBuffer.flip(); } else { throughBuffer.position(throughBuffer.limit()); } return written; } finally { - data.getBuffer().limit(old); + localData.getBuffer().limit(old); decrementFrameDataRemaining(); } } else { throughBuffer.position(throughBuffer.limit()); } return 0; + } catch (IllegalStateException e) { + // NPE should be covered. ISE in case of closed buffer + if (anyAreSet(state, STATE_DONE | STATE_CLOSED | STATE_STREAM_BROKEN)) { + return -1; + } else { + throw e; + } } finally { exitRead(); } @@ -589,29 +608,29 @@ private void beforeRead() throws IOException { } private void exitRead() throws IOException { - if (data != null && !data.getBuffer().hasRemaining()) { - data.close(); - data = null; - } - if (frameDataRemaining == 0) { - try { - synchronized (lock) { + synchronized (lock) { + if (data != null && !data.getBuffer().hasRemaining()) { + data.close(); + data = null; + } + if (frameDataRemaining == 0) { + try { readFrameCount++; if (pendingFrameData.isEmpty()) { if (anyAreSet(state, STATE_RETURNED_MINUS_ONE)) { state |= STATE_DONE; complete(); close(); - } else if(anyAreSet(state, STATE_LAST_FRAME)) { + } else if (anyAreSet(state, STATE_LAST_FRAME)) { state |= STATE_WAITNG_MINUS_ONE; } else { waitingForFrame = true; } } - } - } finally { - if (pendingFrameData.isEmpty()) { - framedChannel.notifyFrameReadComplete(this); + } finally { + if (pendingFrameData.isEmpty()) { + framedChannel.notifyFrameReadComplete(this); + } } } } diff --git a/core/src/main/java/io/undertow/server/protocol/http2/Http2OpenListener.java b/core/src/main/java/io/undertow/server/protocol/http2/Http2OpenListener.java index 5d0015fda8..1a003db075 100644 --- a/core/src/main/java/io/undertow/server/protocol/http2/Http2OpenListener.java +++ b/core/src/main/java/io/undertow/server/protocol/http2/Http2OpenListener.java @@ -23,6 +23,8 @@ import io.undertow.UndertowOptions; import io.undertow.conduits.BytesReceivedStreamSourceConduit; import io.undertow.conduits.BytesSentStreamSinkConduit; +import io.undertow.conduits.ReadTimeoutStreamSourceConduit; +import io.undertow.conduits.WriteTimeoutStreamSinkConduit; import io.undertow.protocols.http2.Http2Channel; import io.undertow.server.ConnectorStatistics; import io.undertow.server.ConnectorStatisticsImpl; @@ -32,12 +34,14 @@ import org.xnio.ChannelListener; import org.xnio.IoUtils; import org.xnio.OptionMap; +import org.xnio.Options; import io.undertow.connector.ByteBufferPool; import io.undertow.connector.PooledByteBuffer; import org.xnio.Pool; import org.xnio.StreamConnection; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.Collections; import java.util.Set; @@ -120,6 +124,22 @@ public void handleEvent(final StreamConnection channel, PooledByteBuffer buffer) if (idleTimeout != null && idleTimeout > 0) { http2Channel.setIdleTimeout(idleTimeout); } + try { + Integer readTimeout = channel.getOption(Options.READ_TIMEOUT); + if (readTimeout != null && readTimeout > 0) { + channel.getSourceChannel().setConduit(new ReadTimeoutStreamSourceConduit(channel.getSourceChannel().getConduit(), channel, this)); + } + } catch (IOException e) { + UndertowLogger.REQUEST_IO_LOGGER.ioException(e); + } + try { + Integer writeTimeout = channel.getOption(Options.WRITE_TIMEOUT); + if (writeTimeout != null && writeTimeout > 0) { + channel.getSinkChannel().setConduit(new WriteTimeoutStreamSinkConduit(channel.getSinkChannel().getConduit(), channel, this)); + } + } catch (IOException e) { + UndertowLogger.REQUEST_IO_LOGGER.ioException(e); + } if(statisticsEnabled) { channel.getSinkChannel().setConduit(new BytesSentStreamSinkConduit(channel.getSinkChannel().getConduit(), connectorStatistics.sentAccumulator())); channel.getSourceChannel().setConduit(new BytesReceivedStreamSourceConduit(channel.getSourceChannel().getConduit(), connectorStatistics.receivedAccumulator())); diff --git a/core/src/test/java/io/undertow/client/http/Http2ReadTimeoutTestCase.java b/core/src/test/java/io/undertow/client/http/Http2ReadTimeoutTestCase.java new file mode 100644 index 0000000000..d91e073ebe --- /dev/null +++ b/core/src/test/java/io/undertow/client/http/Http2ReadTimeoutTestCase.java @@ -0,0 +1,442 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2022 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.undertow.client.http; + +import io.undertow.Undertow; +import io.undertow.UndertowOptions; +import io.undertow.client.ClientCallback; +import io.undertow.client.ClientConnection; +import io.undertow.client.ClientExchange; +import io.undertow.client.ClientRequest; +import io.undertow.client.ClientResponse; +import io.undertow.client.UndertowClient; +import io.undertow.io.IoCallback; +import io.undertow.io.Receiver.ErrorCallback; +import io.undertow.io.Receiver.PartialBytesCallback; +import io.undertow.io.Sender; +import io.undertow.protocols.ssl.UndertowXnioSsl; +import io.undertow.server.HttpHandler; +import io.undertow.server.HttpServerExchange; +import io.undertow.server.handlers.PathHandler; +import io.undertow.testutils.DefaultServer; +import io.undertow.testutils.ProxyIgnore; +import io.undertow.testutils.StopServerWithExternalWorkerUtils; +import io.undertow.util.AttachmentKey; +import io.undertow.util.Headers; +import io.undertow.util.Methods; +import io.undertow.util.Protocols; +import io.undertow.util.StatusCodes; +import io.undertow.util.StringReadChannelListener; +import io.undertow.util.WorkerUtils; +import java.io.IOException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.xnio.ChannelListener; +import org.xnio.ChannelListeners; +import org.xnio.IoUtils; +import org.xnio.OptionMap; +import org.xnio.Options; +import org.xnio.Xnio; +import org.xnio.XnioWorker; +import org.xnio.channels.StreamSinkChannel; + +/** + *

Test class for the READ_TIMEOUT in the HTTP2 listener.

+ * + * @author rmartinc + */ +@RunWith(DefaultServer.class) +@ProxyIgnore +public class Http2ReadTimeoutTestCase { + + private static final String message = "012345678901234567890123456789"; + public static final String MESSAGE = "/message"; + + private static final int READ_TIMEOUT = 5000; + private static final OptionMap DEFAULT_OPTIONS; + private static XnioWorker worker; + private static Undertow server; + private static URL ADDRESS; + + private static final AttachmentKey RESPONSE_BODY = AttachmentKey.create(String.class); + + static { + final OptionMap.Builder builder = OptionMap.builder() + .set(Options.WORKER_IO_THREADS, 8) + .set(Options.TCP_NODELAY, true) + .set(Options.KEEP_ALIVE, true) + .set(Options.WORKER_NAME, "Client"); + + DEFAULT_OPTIONS = builder.getMap(); + } + + @BeforeClass + public static void beforeClass() throws IOException { + + int port = DefaultServer.getHostPort("default"); + + final PathHandler path = new PathHandler() + .addExactPath(MESSAGE, new HttpHandler() { + + /** + * The method just returns the size of the data received. + */ + @Override + public void handleRequest(HttpServerExchange exchange) throws Exception { + final boolean blocking = Boolean.parseBoolean(exchange.getQueryParameters().get("blocking").getFirst()); + if (blocking) { + if (exchange.isInIoThread()) { + // do blocking + exchange.startBlocking(); + exchange.dispatch(this); + return; + } + } + exchange.setStatusCode(StatusCodes.OK); + ReceiverCallback callback = new ReceiverCallback(exchange.getResponseSender()); + exchange.getRequestReceiver().receivePartialBytes(callback, callback); + } + }); + + server = Undertow.builder() + .setByteBufferPool(DefaultServer.getBufferPool()) + .addHttpsListener(port + 1, DefaultServer.getHostAddress("default"), DefaultServer.getServerSslContext()) + .setServerOption(UndertowOptions.ENABLE_HTTP2, true) + .setSocketOption(Options.READ_TIMEOUT, READ_TIMEOUT) + .setSocketOption(Options.REUSE_ADDRESSES, true) + .setHandler(new HttpHandler() { + @Override + public void handleRequest(HttpServerExchange exchange) throws Exception { + if (!exchange.getProtocol().equals(Protocols.HTTP_2_0)) { + throw new RuntimeException("Not HTTP/2"); + } + path.handleRequest(exchange); + } + }) + .build(); + + server.start(); + ADDRESS = new URL("https://" + DefaultServer.getHostAddress() + ":" + (port + 1)); + + // Create xnio worker + final Xnio xnio = Xnio.getInstance(); + final XnioWorker xnioWorker = xnio.createWorker(null, DEFAULT_OPTIONS); + worker = xnioWorker; + } + + @AfterClass + public static void afterClass() { + if (server != null) { + server.stop(); + } + if (worker != null) { + StopServerWithExternalWorkerUtils.stopWorker(worker); + } + } + + @Test + public void testBlockingSuccess() throws Exception { + // test blocking stopping writes less that READ_TIMEOUT + test(true, message.length() * 3, READ_TIMEOUT / 2, false, false); + } + + @Test + public void testNonBlockingSuccess() throws Exception { + // test non-blocking stopping writes less that READ_TIMEOUT + test(false, message.length() * 3, READ_TIMEOUT / 2, false, false); + } + + @Test + public void testBlockingException() throws Exception { + // test blocking stopping writes more that READ_TIMEOUT => exception expected + test(true, message.length() * 3, READ_TIMEOUT * 2, false, true); + } + + @Test + public void testNonBlockingException() throws Exception { + // test non-blocking stopping writes more that READ_TIMEOUT => exception expected + test(false, message.length() * 3, READ_TIMEOUT * 2, false, true); + } + + @Test + public void testBlockingRepetitiveSuccess() throws Exception { + // test blocking repetitive to check that tasks are updated + test(true, message.length() * 7, READ_TIMEOUT / 5, true, false); + } + + @Test + public void testNonBlockingRepetitiveSuccess() throws Exception { + // test non-blocking repetitive to check that tasks are updated + test(false, message.length() * 7, READ_TIMEOUT / 5, true, false); + } + + /** + * The internal test method. The client sends a POST but it starts to write + * the post data after a timeout. If repetitiveTimeout is true the client + * writes in chunks (message size) waiting timeout millis between every + * chunk. The test waits a max time of READ_TIME * 2. + * + * @param blocking true use blocking, false use non-blocking + * @param size The size of the message to send + * @param timeout The initial timeout before writing to the server + * @param repetitiveTimeout If the timeout should be done repetitively + * @param expectedException true if exception is expected, false if not + * @throws Exception Some error + */ + private void test(final boolean blocking, final int size, final int timeout, final boolean repetitiveTimeout, boolean expectedException) throws Exception { + // create the client with a small window size + final UndertowClient client = UndertowClient.getInstance(); + final ClientConnection connection = client.connect(ADDRESS.toURI(), worker, + new UndertowXnioSsl(worker.getXnio(), OptionMap.EMPTY, DefaultServer.getClientSSLContext()), + DefaultServer.getBufferPool(), + OptionMap.builder() + .set(UndertowOptions.ENABLE_HTTP2, true) + .getMap() + ).get(); + + final List responses = new CopyOnWriteArrayList<>(); + final CountDownLatch latch = new CountDownLatch(1); + + try { + long startTime = System.currentTimeMillis(); + connection.getIoThread().execute(() -> { + final ClientRequest request = new ClientRequest().setMethod(Methods.POST).setPath(MESSAGE + "?blocking=" + blocking); + request.getRequestHeaders().put(Headers.HOST, DefaultServer.getHostAddress()); + request.getRequestHeaders().put(Headers.CONTENT_LENGTH, size); + connection.sendRequest(request, createClientCallback(size, timeout, repetitiveTimeout, responses, latch)); + }); + + Assert.assertTrue("Response did not come in the specified time", latch.await(READ_TIMEOUT * 2, TimeUnit.MILLISECONDS)); + Assert.assertEquals("Incorrect number of responses returned", 1, responses.size()); + ClientResponseOrException response = responses.iterator().next(); + if (expectedException) { + Assert.assertFalse("Expected exception but was a response", response.isResponse()); + Assert.assertTrue("The timeout was not triggered at READ_TIMEOUT", System.currentTimeMillis() - startTime < timeout); + } else { + Assert.assertTrue("Expected response but was a exception", response.isResponse()); + Assert.assertEquals("Incorrect status code", StatusCodes.OK, response.getResponse().getResponseCode()); + final String body = response.getResponse().getAttachment(RESPONSE_BODY); + Assert.assertEquals("Unexpected size received", size, Integer.parseInt(body)); + } + } finally { + IoUtils.safeClose(connection); + } + } + + private ClientCallback createClientCallback(final int size, final int timeout, + final boolean repetitiveTimeout, final List responses, final CountDownLatch latch) { + return new ClientCallback() { + @Override + public void completed(ClientExchange result) { + WriteChannelListener writeListener = new WriteChannelListener(result, size, repetitiveTimeout? timeout : 0); + result.getRequestChannel().suspendWrites(); + result.getRequestChannel().getWriteSetter().set(writeListener); + if (timeout > 0) { + // if timeout starts the listener after it + WorkerUtils.executeAfter(result.getRequestChannel().getIoThread(), writeListener, timeout, TimeUnit.MILLISECONDS); + } else { + // no timeout, just start writing + writeListener.run(); + } + result.setResponseListener(new ClientCallback() { + + @Override + public void completed(ClientExchange result) { + responses.add(new ClientResponseOrException(result.getResponse())); + new StringReadChannelListener(DefaultServer.getBufferPool()) { + + @Override + protected void stringDone(String string) { + result.getResponse().putAttachment(RESPONSE_BODY, string); + latch.countDown(); + } + + @Override + protected void error(IOException e) { + responses.add(new ClientResponseOrException(e)); + latch.countDown(); + } + }.setup(result.getResponseChannel()); + } + + @Override + public void failed(IOException e) { + responses.add(new ClientResponseOrException(e)); + latch.countDown(); + } + }); + } + + @Override + public void failed(IOException e) { + responses.add(new ClientResponseOrException(e)); + latch.countDown(); + } + }; + } + + /** + * A partial bytes callback that counts the bytes received and writes the + * final size back as the response. + */ + private static class ReceiverCallback implements PartialBytesCallback, ErrorCallback { + + private final Sender sender; + private int size; + + ReceiverCallback(Sender sender) { + this.sender = sender; + size = 0; + } + + @Override + public void handle(HttpServerExchange exchange, byte[] message, boolean last) { + size += message.length; + if (last) { + sender.send(Integer.toString(size)); + } + } + + @Override + public void error(HttpServerExchange exchange, IOException e) { + IoCallback.END_EXCHANGE.onException(exchange, sender, e); + } + } + + /** + * A channel listener that writes the message the times needed until + * size bytes are sent. If a timeout is passed the listener + * uses task to write the data waiting the timeout between writing every + * message. + */ + private class WriteChannelListener implements Runnable, ChannelListener { + + private final ClientExchange result; + private int size; + private final int timeout; + private final ByteBuffer buffer; + + WriteChannelListener(ClientExchange result, int size, int timeout) { + this.result = result; + this.size = size; + this.timeout = timeout; + this.buffer = ByteBuffer.wrap(message.getBytes()); + wrapBuffer(); + } + + @Override + public void run() { + this.handleEvent(result.getRequestChannel()); + } + + @Override + public void handleEvent(StreamSinkChannel channel) { + try { + int c; + do { + c = channel.write(buffer); + size = size - c; + if (!buffer.hasRemaining() && size > 0) { + wrapBuffer(); + if (timeout > 0) { + if (!channel.flush()) { + // force resume writes + c = 0; + } + break; + } + } + } while (c > 0); + + if (size == 0) { + writeDone(channel); + } else if (c > 0 && timeout > 0) { + channel.suspendWrites(); + WorkerUtils.executeAfter(channel.getIoThread(), this, timeout, TimeUnit.MILLISECONDS); + } else if (!channel.isWriteResumed()) { + channel.resumeWrites(); + } + } catch (IOException e) { + IoUtils.safeClose(channel); + } + } + + private void wrapBuffer() { + buffer.position(0); + if (size < buffer.capacity()) { + buffer.limit(size); + } else { + buffer.limit(buffer.capacity()); + } + } + + private void writeDone(final StreamSinkChannel channel) { + try { + channel.shutdownWrites(); + if (!channel.flush()) { + channel.getWriteSetter().set(ChannelListeners.flushingChannelListener( + c -> IoUtils.safeClose(c), + ChannelListeners.closingChannelExceptionHandler())); + channel.resumeWrites(); + } + } catch (IOException e) { + IoUtils.safeClose(channel); + } + } + } + + /** + * Class to store the client response or the exception. + */ + private class ClientResponseOrException { + private final ClientResponse response; + private final IOException exception; + + ClientResponseOrException(ClientResponse response) { + this.response = response; + this.exception = null; + } + + ClientResponseOrException(IOException exception) { + this.response = null; + this.exception = exception; + } + + public ClientResponse getResponse() { + return response; + } + + public IOException getException() { + return exception; + } + + public boolean isResponse() { + return response != null; + } + } +} diff --git a/servlet/src/main/java/io/undertow/servlet/spec/ServletOutputStreamImpl.java b/servlet/src/main/java/io/undertow/servlet/spec/ServletOutputStreamImpl.java index 1efff047d6..92efa43a98 100644 --- a/servlet/src/main/java/io/undertow/servlet/spec/ServletOutputStreamImpl.java +++ b/servlet/src/main/java/io/undertow/servlet/spec/ServletOutputStreamImpl.java @@ -638,7 +638,7 @@ public void close() throws IOException { if (buffer == null && (contentLength == null || !Methods.HEAD_STRING.equals(servletRequestContext.getOriginalRequest().getMethod()))) { servletRequestContext.getExchange().getResponseHeaders().put(Headers.CONTENT_LENGTH, "0"); } else if (buffer != null && contentLength == null) { - servletRequestContext.getExchange().getResponseHeaders().put(Headers.CONTENT_LENGTH, Integer.toString(buffer.position())); + servletRequestContext.getExchange().getResponseHeaders().put(Headers.CONTENT_LENGTH, buffer.position()); } } } diff --git a/servlet/src/test/java/io/undertow/servlet/test/ProxyForwardedTestCase.java b/servlet/src/test/java/io/undertow/servlet/test/ProxyForwardedTestCase.java index 399d8a0ad4..c89de53479 100644 --- a/servlet/src/test/java/io/undertow/servlet/test/ProxyForwardedTestCase.java +++ b/servlet/src/test/java/io/undertow/servlet/test/ProxyForwardedTestCase.java @@ -44,7 +44,6 @@ import jakarta.servlet.ServletException; import java.io.IOException; import java.net.InetSocketAddress; -import java.net.Socket; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -56,11 +55,9 @@ @RunWith(DefaultServer.class) @ProxyIgnore public class ProxyForwardedTestCase { - protected static int PORT; @BeforeClass public static void setup() throws ServletException { - PORT = DefaultServer.getHostPort("default"); final PathHandler root = new PathHandler(); final ServletContainer container = ServletContainer.Factory.newInstance(); @@ -98,13 +95,12 @@ public void testForwardedHandler() throws IOException { HttpEntity entity = result.getEntity(); String results = EntityUtils.toString(entity); Map map = convertWithStream(results); - Socket socket = new Socket(); - socket.connect(new InetSocketAddress(DefaultServer.getHostAddress(), PORT)); + InetSocketAddress serverAddress = DefaultServer.getDefaultServerAddress(); Assert.assertEquals(StatusCodes.OK, result.getStatusLine().getStatusCode()); - Assert.assertEquals(socket.getLocalAddress().getHostAddress(), map.get(GenericServletConstants.LOCAL_ADDR)); - Assert.assertEquals(socket.getLocalAddress().getHostName(), map.get(GenericServletConstants.LOCAL_NAME)); - Assert.assertEquals(PORT, Integer.parseInt(map.get(GenericServletConstants.LOCAL_PORT))); + Assert.assertEquals(serverAddress.getAddress().getHostAddress(), map.get(GenericServletConstants.LOCAL_ADDR)); + Assert.assertEquals(serverAddress.getAddress().getHostName(), map.get(GenericServletConstants.LOCAL_NAME)); + Assert.assertEquals(serverAddress.getPort(), Integer.parseInt(map.get(GenericServletConstants.LOCAL_PORT))); Assert.assertEquals("192.0.2.10", map.get(GenericServletConstants.SERVER_NAME)); Assert.assertEquals("8888", map.get(GenericServletConstants.SERVER_PORT)); Assert.assertEquals("192.0.2.43", map.get(GenericServletConstants.REMOTE_ADDR)); diff --git a/servlet/src/test/java/io/undertow/servlet/test/ProxyXForwardedTestCase.java b/servlet/src/test/java/io/undertow/servlet/test/ProxyXForwardedTestCase.java index b67367c8dc..48066f8a45 100644 --- a/servlet/src/test/java/io/undertow/servlet/test/ProxyXForwardedTestCase.java +++ b/servlet/src/test/java/io/undertow/servlet/test/ProxyXForwardedTestCase.java @@ -44,7 +44,6 @@ import jakarta.servlet.ServletException; import java.io.IOException; import java.net.InetSocketAddress; -import java.net.Socket; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -56,11 +55,9 @@ @RunWith(DefaultServer.class) @ProxyIgnore public class ProxyXForwardedTestCase { - protected static int PORT; @BeforeClass public static void setup() throws ServletException { - PORT = DefaultServer.getHostPort("default"); final PathHandler root = new PathHandler(); final ServletContainer container = ServletContainer.Factory.newInstance(); @@ -96,13 +93,12 @@ public void testProxyPeerHandler() throws IOException, ServletException { HttpEntity entity = result.getEntity(); String results = EntityUtils.toString(entity); Map map = convertWithStream(results); - Socket socket = new Socket(); - socket.connect(new InetSocketAddress(DefaultServer.getHostAddress(), PORT)); + InetSocketAddress serverAddress = DefaultServer.getDefaultServerAddress(); Assert.assertEquals(StatusCodes.OK, result.getStatusLine().getStatusCode()); - Assert.assertEquals(socket.getLocalAddress().getHostAddress(), map.get(GenericServletConstants.LOCAL_ADDR)); - Assert.assertEquals(socket.getLocalAddress().getHostName(), map.get(GenericServletConstants.LOCAL_NAME)); - Assert.assertEquals(PORT, Integer.parseInt(map.get(GenericServletConstants.LOCAL_PORT))); + Assert.assertEquals(serverAddress.getAddress().getHostAddress(), map.get(GenericServletConstants.LOCAL_ADDR)); + Assert.assertEquals(serverAddress.getAddress().getHostName(), map.get(GenericServletConstants.LOCAL_NAME)); + Assert.assertEquals(serverAddress.getPort(), Integer.parseInt(map.get(GenericServletConstants.LOCAL_PORT))); Assert.assertEquals("192.0.2.10", map.get(GenericServletConstants.SERVER_NAME)); Assert.assertEquals("8888", map.get(GenericServletConstants.SERVER_PORT)); Assert.assertEquals("192.0.2.43", map.get(GenericServletConstants.REMOTE_ADDR));