diff --git a/core/src/main/java/io/undertow/client/http2/Http2ClientConnection.java b/core/src/main/java/io/undertow/client/http2/Http2ClientConnection.java index 1298d4d888..e28e38a445 100644 --- a/core/src/main/java/io/undertow/client/http2/Http2ClientConnection.java +++ b/core/src/main/java/io/undertow/client/http2/Http2ClientConnection.java @@ -465,9 +465,14 @@ public void handleEvent(Http2StreamSourceChannel channel) { Http2ClientExchange newExchange = new Http2ClientExchange(Http2ClientConnection.this, null, cr); if(!request.getPushCallback().handlePush(request, newExchange)) { + // if no push handler just reset the stream channel.sendRstStream(stream.getPushedStreamId(), Http2Channel.ERROR_REFUSED_STREAM); IoUtils.safeClose(stream); + } else if (!http2Channel.addPushPromiseStream(stream.getPushedStreamId())) { + // if invalid stream id send connection error of type PROTOCOL_ERROR as spec + channel.sendGoAway(Http2Channel.ERROR_PROTOCOL_ERROR); } else { + // add the pushed stream to current exchanges currentExchanges.put(stream.getPushedStreamId(), newExchange); } } diff --git a/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java b/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java index 60dc3c24c1..a8a4b6fcd1 100644 --- a/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java +++ b/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java @@ -160,6 +160,7 @@ public class Http2Channel extends AbstractFramedChannel update the last assigned for the server + lastAssignedStreamOtherSide = lastGoodStreamId; + } + } else if (isClient()) { + // client received push promise => update the last assigned for the client + lastAssignedStreamOtherSide = Math.max(lastAssignedStreamOtherSide, streamNo); + } + } + public synchronized Http2HeadersStreamSinkChannel sendPushPromise(int associatedStreamId, HeaderMap requestHeaders, HeaderMap responseHeaders) throws IOException { if (!isOpen()) { throw UndertowMessages.MESSAGES.channelIsClosed(); @@ -1084,7 +1143,7 @@ private void handleRstStream(int streamId) { * * @return */ - public Http2HeadersStreamSinkChannel createInitialUpgradeResponseStream() { + public synchronized Http2HeadersStreamSinkChannel createInitialUpgradeResponseStream() { if (lastGoodStreamId != 0) { throw new IllegalStateException(); } @@ -1159,9 +1218,11 @@ public String getProtocol() { private synchronized boolean isIdle(int streamNo) { if(streamNo % 2 == streamIdCounter % 2) { + // our side is controlled by us in the generated streamIdCounter return streamNo >= streamIdCounter; } else { - return streamNo > lastGoodStreamId; + // the other side should increase lastAssignedStreamOtherSide all the time + return streamNo > lastAssignedStreamOtherSide; } } diff --git a/core/src/main/java/io/undertow/server/protocol/http2/Http2ServerConnection.java b/core/src/main/java/io/undertow/server/protocol/http2/Http2ServerConnection.java index 8af6787b06..5714c7f58c 100644 --- a/core/src/main/java/io/undertow/server/protocol/http2/Http2ServerConnection.java +++ b/core/src/main/java/io/undertow/server/protocol/http2/Http2ServerConnection.java @@ -399,7 +399,10 @@ public T getAttachment(AttachmentKey key) { @Override public boolean isPushSupported() { - return channel.isPushEnabled() && !exchange.getRequestHeaders().contains(Headers.X_DISABLE_PUSH); + return channel.isPushEnabled() + && !exchange.getRequestHeaders().contains(Headers.X_DISABLE_PUSH) + // push is not supported for already pushed streams, just for peer-initiated (odd) ids + && responseChannel.getStreamId() % 2 != 0; } @Override diff --git a/servlet/src/test/java/io/undertow/servlet/test/push/PushPromisesTestCase.java b/servlet/src/test/java/io/undertow/servlet/test/push/PushPromisesTestCase.java new file mode 100644 index 0000000000..93be473765 --- /dev/null +++ b/servlet/src/test/java/io/undertow/servlet/test/push/PushPromisesTestCase.java @@ -0,0 +1,224 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2021 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.undertow.servlet.test.push; + +import io.undertow.UndertowOptions; +import io.undertow.client.ClientCallback; +import io.undertow.client.ClientConnection; +import io.undertow.client.ClientExchange; +import io.undertow.client.ClientRequest; +import io.undertow.client.ClientResponse; +import io.undertow.client.PushCallback; +import io.undertow.client.UndertowClient; +import io.undertow.protocols.ssl.UndertowXnioSsl; +import io.undertow.server.OpenListener; +import io.undertow.server.handlers.PathHandler; +import io.undertow.server.protocol.http.AlpnOpenListener; +import io.undertow.server.protocol.http2.Http2OpenListener; +import io.undertow.servlet.api.DeploymentInfo; +import io.undertow.servlet.api.DeploymentManager; +import io.undertow.servlet.api.ServletContainer; +import io.undertow.servlet.api.ServletInfo; +import io.undertow.servlet.test.util.TestClassIntrospector; +import io.undertow.testutils.DefaultServer; +import io.undertow.testutils.ProxyIgnore; +import io.undertow.util.AttachmentKey; +import io.undertow.util.Headers; +import io.undertow.util.Methods; +import io.undertow.util.SingleByteStreamSinkConduit; +import io.undertow.util.SingleByteStreamSourceConduit; +import io.undertow.util.StatusCodes; +import io.undertow.util.StringReadChannelListener; +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.xnio.ChannelListener; +import org.xnio.ChannelListeners; +import org.xnio.IoUtils; +import org.xnio.OptionMap; +import org.xnio.Options; +import org.xnio.StreamConnection; +import org.xnio.Xnio; +import org.xnio.XnioWorker; + +/** + *

Test that checks that push promises are returned and double promises + * are avoid (a resource sent as a promise trigers another promise).

+ * + * @author rmartinc + */ +@RunWith(DefaultServer.class) +@ProxyIgnore +public class PushPromisesTestCase { + + private static final AttachmentKey RESPONSE_BODY = AttachmentKey.create(String.class); + + private static OpenListener openListener; + private static ChannelListener acceptListener; + private static XnioWorker worker; + + private static ChannelListener wrapOpenListener(final ChannelListener listener) { + return (StreamConnection channel) -> { + channel.getSinkChannel().setConduit(new SingleByteStreamSinkConduit(channel.getSinkChannel().getConduit(), 10000)); + channel.getSourceChannel().setConduit(new SingleByteStreamSourceConduit(channel.getSourceChannel().getConduit(), 10000)); + listener.handleEvent(channel); + }; + } + + @BeforeClass + public static void setup() throws Exception { + final PathHandler root = new PathHandler(); + final ServletContainer container = ServletContainer.Factory.newInstance(); + ServletInfo s = new ServletInfo("servlet", PushServlet.class) + .addMappings("/index.html", "/resources/*"); + DeploymentInfo info = new DeploymentInfo() + .setClassLoader(PushPromisesTestCase.class.getClassLoader()) + .setContextPath("/push-example") + .setClassIntrospecter(TestClassIntrospector.INSTANCE) + .setDeploymentName("push-example.war") + .addServlet(s); + + DeploymentManager manager = container.addDeployment(info); + manager.deploy(); + root.addPrefixPath(info.getContextPath(), manager.start()); + + openListener = new Http2OpenListener(DefaultServer.getBufferPool(), OptionMap.create(UndertowOptions.ENABLE_HTTP2, true, UndertowOptions.HTTP2_PADDING_SIZE, 10)); + acceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(new AlpnOpenListener(DefaultServer.getBufferPool()).addProtocol(Http2OpenListener.HTTP2, (io.undertow.server.DelegateOpenListener) openListener, 10))); + openListener.setRootHandler(root); + + DefaultServer.startSSLServer(OptionMap.EMPTY, acceptListener); + + final Xnio xnio = Xnio.getInstance(); + final XnioWorker xnioWorker = xnio.createWorker(null, + OptionMap.builder() + .set(Options.WORKER_IO_THREADS, 8) + .set(Options.TCP_NODELAY, true) + .set(Options.KEEP_ALIVE, true) + .set(Options.WORKER_NAME, "Client").getMap()); + worker = xnioWorker; + } + + @AfterClass + public static void cleanUp() throws Exception { + openListener.closeConnections(); + DefaultServer.stopSSLServer(); + } + + private PushCallback createPushCallback(final Map responses, final CountDownLatch latch) { + return new PushCallback() { + @Override + public boolean handlePush(ClientExchange originalRequest, ClientExchange pushedRequest) { + pushedRequest.setResponseListener(new ResponseListener(responses, latch)); + return true; + } + }; + } + + private ClientCallback createClientCallback(final Map responses, final CountDownLatch latch) { + return new ClientCallback() { + @Override + public void completed(final ClientExchange result) { + result.setResponseListener(new ResponseListener(responses, latch)); + result.setPushHandler(createPushCallback(responses, latch)); + } + + @Override + public void failed(IOException e) { + e.printStackTrace(); + latch.countDown(); + } + }; + } + + private static class ResponseListener implements ClientCallback { + + private final Map responses; + private final CountDownLatch latch; + + ResponseListener(Map responses, CountDownLatch latch) { + this.responses = responses; + this.latch = latch; + } + + @Override + public void completed(final ClientExchange result) { + responses.put(result.getRequest().getPath(), result.getResponse()); + new StringReadChannelListener(result.getConnection().getBufferPool()) { + + @Override + protected void stringDone(String string) { + result.getResponse().putAttachment(RESPONSE_BODY, string); + latch.countDown(); + } + + @Override + protected void error(IOException e) { + e.printStackTrace(); + latch.countDown(); + } + }.setup(result.getResponseChannel()); + } + + @Override + public void failed(IOException e) { + e.printStackTrace(); + latch.countDown(); + } + } + + @Test + public void testPushPromises() throws Exception { + URI uri = new URI(DefaultServer.getDefaultServerSSLAddress()); + final UndertowClient client = UndertowClient.getInstance(); + final Map responses = new ConcurrentHashMap<>(); + final CountDownLatch latch = new CountDownLatch(3); + final ClientConnection connection = client.connect(uri, worker, new UndertowXnioSsl(worker.getXnio(), OptionMap.EMPTY, DefaultServer.getClientSSLContext()), DefaultServer.getBufferPool(), OptionMap.create(UndertowOptions.ENABLE_HTTP2, true)) + .get(); + try { + connection.getIoThread().execute(new Runnable() { + @Override + public void run() { + final ClientRequest request = new ClientRequest().setMethod(Methods.GET).setPath("/push-example/index.html"); + request.getRequestHeaders().put(Headers.HOST, DefaultServer.getHostAddress()); + connection.sendRequest(request, createClientCallback(responses, latch)); + } + }); + latch.await(10, TimeUnit.SECONDS); + Assert.assertEquals(3, responses.size()); + Assert.assertTrue(responses.containsKey("/push-example/index.html")); + Assert.assertEquals(StatusCodes.OK, responses.get("/push-example/index.html").getResponseCode()); + Assert.assertNotNull(responses.get("/push-example/index.html").getAttachment(RESPONSE_BODY)); + Assert.assertTrue(responses.containsKey("/push-example/resources/one.js")); + Assert.assertEquals(StatusCodes.OK, responses.get("/push-example/resources/one.js").getResponseCode()); + Assert.assertNotNull(responses.get("/push-example/resources/one.js").getAttachment(RESPONSE_BODY)); + Assert.assertTrue(responses.containsKey("/push-example/resources/one.css")); + Assert.assertEquals(StatusCodes.OK, responses.get("/push-example/resources/one.css").getResponseCode()); + Assert.assertNotNull(responses.get("/push-example/resources/one.css").getAttachment(RESPONSE_BODY)); + } finally { + IoUtils.safeClose(connection); + } + } +} diff --git a/servlet/src/test/java/io/undertow/servlet/test/push/PushServlet.java b/servlet/src/test/java/io/undertow/servlet/test/push/PushServlet.java new file mode 100644 index 0000000000..d367d7cc67 --- /dev/null +++ b/servlet/src/test/java/io/undertow/servlet/test/push/PushServlet.java @@ -0,0 +1,90 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2021 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.undertow.servlet.test.push; + +import java.io.IOException; +import java.util.Base64; +import javax.servlet.ServletException; +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.PushBuilder; + +/** + *

Simple servlet that pushes resources for index.html and *.css. The idea + * is that double promises are not sent as they are forbidden by the spec.

+ * + * @author rmartinc + */ +public class PushServlet extends HttpServlet { + + // A simple blue pixel in PNG format + private static final String PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="; + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + if (request.getServletPath().endsWith(".html") || request.getPathInfo().endsWith(".html")) { + PushBuilder pushBuilder = request.newPushBuilder(); + if (pushBuilder != null) { + // pushing css and js in advance + pushBuilder.path("resources/one.css").push(); + pushBuilder.path("resources/one.js").push(); + } + try (ServletOutputStream out = response.getOutputStream()) { + out.println(""); + out.println(" "); + out.println(" "); + out.println(" "); + out.println(" "); + out.println(" "); + out.println(" PUSH PROMISES"); + out.println(" "); + out.println(""); + } + } else if (request.getPathInfo().endsWith(".css")) { + PushBuilder pushBuilder = request.newPushBuilder(); + if (pushBuilder != null) { + // pushing images in advance + pushBuilder.path("resources/one.png").push(); + } + response.setContentType("text/css"); + try (ServletOutputStream out = response.getOutputStream()) { + out.println("body, html {"); + out.println(" height: 100%;"); + out.println(" margin: 0;"); + out.println(" background-image: url(\"one.png\");"); + out.println(" background-repeat: repeat;"); + out.println("}"); + } + } else if (request.getPathInfo().endsWith(".js")) { + response.setContentType("application/javascript"); + try (ServletOutputStream out = response.getOutputStream()) { + out.println("console.log('loading js file ' + location.pathname);"); + } + } else if (request.getPathInfo().endsWith(".png")) { + byte[] bytes = Base64.getDecoder().decode(PNG_BASE64); + response.setContentType("image/png"); + try (ServletOutputStream out = response.getOutputStream()) { + out.write(bytes); + } + } else { + throw new ServletException("Invalid request"); + } + } +}