Skip to content

Commit

Permalink
Merge pull request #1085 from rmartinc/UNDERTOW-1877
Browse files Browse the repository at this point in the history
[UNDERTOW-1877] HTTP2 implementation returns PUSH PROMISES for frames…
  • Loading branch information
fl4via authored Apr 15, 2021
2 parents b5a0132 + cde915a commit 43244cc
Show file tree
Hide file tree
Showing 5 changed files with 389 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,14 @@ public void handleEvent(Http2StreamSourceChannel channel) {
Http2ClientExchange newExchange = new Http2ClientExchange(Http2ClientConnection.this, null, cr);

if(!request.getPushCallback().handlePush(request, newExchange)) {
// if no push handler just reset the stream
channel.sendRstStream(stream.getPushedStreamId(), Http2Channel.ERROR_REFUSED_STREAM);
IoUtils.safeClose(stream);
} else if (!http2Channel.addPushPromiseStream(stream.getPushedStreamId())) {
// if invalid stream id send connection error of type PROTOCOL_ERROR as spec
channel.sendGoAway(Http2Channel.ERROR_PROTOCOL_ERROR);
} else {
// add the pushed stream to current exchanges
currentExchanges.put(stream.getPushedStreamId(), newExchange);
}
}
Expand Down
71 changes: 66 additions & 5 deletions core/src/main/java/io/undertow/protocols/http2/Http2Channel.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ public class Http2Channel extends AbstractFramedChannel<Http2Channel, AbstractHt

private int streamIdCounter;
private int lastGoodStreamId;
private int lastAssignedStreamOtherSide;

private final HpackDecoder decoder;
private final HpackEncoder encoder;
Expand Down Expand Up @@ -403,7 +404,7 @@ protected AbstractHttp2StreamSourceChannel createChannelImpl(FrameHeaderData fra
}
}
} else {
if(frameParser.streamId < lastGoodStreamId) {
if(frameParser.streamId < getLastAssignedStreamOtherSide()) {
sendGoAway(ERROR_PROTOCOL_ERROR);
frameData.close();
return null;
Expand All @@ -417,7 +418,8 @@ protected AbstractHttp2StreamSourceChannel createChannelImpl(FrameHeaderData fra
Http2HeadersParser parser = (Http2HeadersParser) frameParser.parser;

channel = new Http2StreamSourceChannel(this, frameData, frameHeaderData.getFrameLength(), parser.getHeaderMap(), frameParser.streamId);
lastGoodStreamId = Math.max(lastGoodStreamId, frameParser.streamId);

updateStreamIdsCountersInHeaders(frameParser.streamId);

StreamHolder holder = currentStreams.get(frameParser.streamId);
if(holder == null) {
Expand Down Expand Up @@ -819,7 +821,7 @@ public void sendGoAway(int status, final ChannelExceptionHandler<AbstractHttp2St
if(UndertowLogger.REQUEST_IO_LOGGER.isTraceEnabled()) {
UndertowLogger.REQUEST_IO_LOGGER.tracef(new ClosedChannelException(), "Sending goaway on channel %s", this);
}
Http2GoAwayStreamSinkChannel goAway = new Http2GoAwayStreamSinkChannel(this, status, lastGoodStreamId);
Http2GoAwayStreamSinkChannel goAway = new Http2GoAwayStreamSinkChannel(this, status, getLastGoodStreamId());
try {
goAway.shutdownWrites();
if (!goAway.flush()) {
Expand Down Expand Up @@ -899,6 +901,63 @@ public synchronized Http2HeadersStreamSinkChannel createStream(HeaderMap request
return http2SynStreamStreamSinkChannel;
}

/**
* Adds a received pushed stream into the current streams for a client. The
* stream is added into the currentStream and lastAssignedStreamOtherSide is incremented.
*
* @param pushedStreamId The pushed stream returned by the server
* @return true if pushedStreamId can be added, false if invalid
* @throws IOException General error like not being a client or odd stream id
*/
public synchronized boolean addPushPromiseStream(int pushedStreamId) throws IOException {
if (!isClient() || pushedStreamId % 2 != 0) {
throw UndertowMessages.MESSAGES.pushPromiseCanOnlyBeCreatedByServer();
}
if (!isOpen()) {
throw UndertowMessages.MESSAGES.channelIsClosed();
}
if (!isIdle(pushedStreamId)) {
UndertowLogger.REQUEST_IO_LOGGER.debugf("Non idle streamId %d received from the server as a pushed stream.", pushedStreamId);
return false;
}
StreamHolder holder = new StreamHolder((Http2HeadersStreamSinkChannel) null);
holder.sinkClosed = true;
lastAssignedStreamOtherSide = Math.max(lastAssignedStreamOtherSide, pushedStreamId);
currentStreams.put(pushedStreamId, holder);
return true;
}

private synchronized int getLastAssignedStreamOtherSide() {
return lastAssignedStreamOtherSide;
}

private synchronized int getLastGoodStreamId() {
return lastGoodStreamId;
}

/**
* Updates the lastGoodStreamId (last request ID to send in goaway frames),
* and lastAssignedStreamOtherSide (the last received streamId from the other
* side to check if it's idle). The lastAssignedStreamOtherSide in a server
* is the same as lastGoodStreamId but in a client push promises can be
* received and check for idle is different.
*
* @param streamNo The received streamId for the client or the server
*/
private synchronized void updateStreamIdsCountersInHeaders(int streamNo) {
if (streamNo % 2 != 0) {
// the last good stream is always the last client ID sent by the client or received by the server
lastGoodStreamId = Math.max(lastGoodStreamId, streamNo);
if (!isClient()) {
// server received client request ID => update the last assigned for the server
lastAssignedStreamOtherSide = lastGoodStreamId;
}
} else if (isClient()) {
// client received push promise => update the last assigned for the client
lastAssignedStreamOtherSide = Math.max(lastAssignedStreamOtherSide, streamNo);
}
}

public synchronized Http2HeadersStreamSinkChannel sendPushPromise(int associatedStreamId, HeaderMap requestHeaders, HeaderMap responseHeaders) throws IOException {
if (!isOpen()) {
throw UndertowMessages.MESSAGES.channelIsClosed();
Expand Down Expand Up @@ -1084,7 +1143,7 @@ private void handleRstStream(int streamId) {
*
* @return
*/
public Http2HeadersStreamSinkChannel createInitialUpgradeResponseStream() {
public synchronized Http2HeadersStreamSinkChannel createInitialUpgradeResponseStream() {
if (lastGoodStreamId != 0) {
throw new IllegalStateException();
}
Expand Down Expand Up @@ -1159,9 +1218,11 @@ public String getProtocol() {

private synchronized boolean isIdle(int streamNo) {
if(streamNo % 2 == streamIdCounter % 2) {
// our side is controlled by us in the generated streamIdCounter
return streamNo >= streamIdCounter;
} else {
return streamNo > lastGoodStreamId;
// the other side should increase lastAssignedStreamOtherSide all the time
return streamNo > lastAssignedStreamOtherSide;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ public <T> T getAttachment(AttachmentKey<T> key) {

@Override
public boolean isPushSupported() {
return channel.isPushEnabled() && !exchange.getRequestHeaders().contains(Headers.X_DISABLE_PUSH);
return channel.isPushEnabled()
&& !exchange.getRequestHeaders().contains(Headers.X_DISABLE_PUSH)
// push is not supported for already pushed streams, just for peer-initiated (odd) ids
&& responseChannel.getStreamId() % 2 != 0;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
/*
* JBoss, Home of Professional Open Source.
* Copyright 2021 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.servlet.test.push;

import io.undertow.UndertowOptions;
import io.undertow.client.ClientCallback;
import io.undertow.client.ClientConnection;
import io.undertow.client.ClientExchange;
import io.undertow.client.ClientRequest;
import io.undertow.client.ClientResponse;
import io.undertow.client.PushCallback;
import io.undertow.client.UndertowClient;
import io.undertow.protocols.ssl.UndertowXnioSsl;
import io.undertow.server.OpenListener;
import io.undertow.server.handlers.PathHandler;
import io.undertow.server.protocol.http.AlpnOpenListener;
import io.undertow.server.protocol.http2.Http2OpenListener;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
import io.undertow.servlet.api.ServletContainer;
import io.undertow.servlet.api.ServletInfo;
import io.undertow.servlet.test.util.TestClassIntrospector;
import io.undertow.testutils.DefaultServer;
import io.undertow.testutils.ProxyIgnore;
import io.undertow.util.AttachmentKey;
import io.undertow.util.Headers;
import io.undertow.util.Methods;
import io.undertow.util.SingleByteStreamSinkConduit;
import io.undertow.util.SingleByteStreamSourceConduit;
import io.undertow.util.StatusCodes;
import io.undertow.util.StringReadChannelListener;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.Options;
import org.xnio.StreamConnection;
import org.xnio.Xnio;
import org.xnio.XnioWorker;

/**
* <p>Test that checks that push promises are returned and double promises
* are avoid (a resource sent as a promise trigers another promise).</p>
*
* @author rmartinc
*/
@RunWith(DefaultServer.class)
@ProxyIgnore
public class PushPromisesTestCase {

private static final AttachmentKey<String> RESPONSE_BODY = AttachmentKey.create(String.class);

private static OpenListener openListener;
private static ChannelListener acceptListener;
private static XnioWorker worker;

private static ChannelListener<StreamConnection> wrapOpenListener(final ChannelListener<StreamConnection> listener) {
return (StreamConnection channel) -> {
channel.getSinkChannel().setConduit(new SingleByteStreamSinkConduit(channel.getSinkChannel().getConduit(), 10000));
channel.getSourceChannel().setConduit(new SingleByteStreamSourceConduit(channel.getSourceChannel().getConduit(), 10000));
listener.handleEvent(channel);
};
}

@BeforeClass
public static void setup() throws Exception {
final PathHandler root = new PathHandler();
final ServletContainer container = ServletContainer.Factory.newInstance();
ServletInfo s = new ServletInfo("servlet", PushServlet.class)
.addMappings("/index.html", "/resources/*");
DeploymentInfo info = new DeploymentInfo()
.setClassLoader(PushPromisesTestCase.class.getClassLoader())
.setContextPath("/push-example")
.setClassIntrospecter(TestClassIntrospector.INSTANCE)
.setDeploymentName("push-example.war")
.addServlet(s);

DeploymentManager manager = container.addDeployment(info);
manager.deploy();
root.addPrefixPath(info.getContextPath(), manager.start());

openListener = new Http2OpenListener(DefaultServer.getBufferPool(), OptionMap.create(UndertowOptions.ENABLE_HTTP2, true, UndertowOptions.HTTP2_PADDING_SIZE, 10));
acceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(new AlpnOpenListener(DefaultServer.getBufferPool()).addProtocol(Http2OpenListener.HTTP2, (io.undertow.server.DelegateOpenListener) openListener, 10)));
openListener.setRootHandler(root);

DefaultServer.startSSLServer(OptionMap.EMPTY, acceptListener);

final Xnio xnio = Xnio.getInstance();
final XnioWorker xnioWorker = xnio.createWorker(null,
OptionMap.builder()
.set(Options.WORKER_IO_THREADS, 8)
.set(Options.TCP_NODELAY, true)
.set(Options.KEEP_ALIVE, true)
.set(Options.WORKER_NAME, "Client").getMap());
worker = xnioWorker;
}

@AfterClass
public static void cleanUp() throws Exception {
openListener.closeConnections();
DefaultServer.stopSSLServer();
}

private PushCallback createPushCallback(final Map<String, ClientResponse> responses, final CountDownLatch latch) {
return new PushCallback() {
@Override
public boolean handlePush(ClientExchange originalRequest, ClientExchange pushedRequest) {
pushedRequest.setResponseListener(new ResponseListener(responses, latch));
return true;
}
};
}

private ClientCallback<ClientExchange> createClientCallback(final Map<String, ClientResponse> responses, final CountDownLatch latch) {
return new ClientCallback<ClientExchange>() {
@Override
public void completed(final ClientExchange result) {
result.setResponseListener(new ResponseListener(responses, latch));
result.setPushHandler(createPushCallback(responses, latch));
}

@Override
public void failed(IOException e) {
e.printStackTrace();
latch.countDown();
}
};
}

private static class ResponseListener implements ClientCallback<ClientExchange> {

private final Map<String, ClientResponse> responses;
private final CountDownLatch latch;

ResponseListener(Map<String, ClientResponse> responses, CountDownLatch latch) {
this.responses = responses;
this.latch = latch;
}

@Override
public void completed(final ClientExchange result) {
responses.put(result.getRequest().getPath(), result.getResponse());
new StringReadChannelListener(result.getConnection().getBufferPool()) {

@Override
protected void stringDone(String string) {
result.getResponse().putAttachment(RESPONSE_BODY, string);
latch.countDown();
}

@Override
protected void error(IOException e) {
e.printStackTrace();
latch.countDown();
}
}.setup(result.getResponseChannel());
}

@Override
public void failed(IOException e) {
e.printStackTrace();
latch.countDown();
}
}

@Test
public void testPushPromises() throws Exception {
URI uri = new URI(DefaultServer.getDefaultServerSSLAddress());
final UndertowClient client = UndertowClient.getInstance();
final Map<String, ClientResponse> responses = new ConcurrentHashMap<>();
final CountDownLatch latch = new CountDownLatch(3);
final ClientConnection connection = client.connect(uri, worker, new UndertowXnioSsl(worker.getXnio(), OptionMap.EMPTY, DefaultServer.getClientSSLContext()), DefaultServer.getBufferPool(), OptionMap.create(UndertowOptions.ENABLE_HTTP2, true))
.get();
try {
connection.getIoThread().execute(new Runnable() {
@Override
public void run() {
final ClientRequest request = new ClientRequest().setMethod(Methods.GET).setPath("/push-example/index.html");
request.getRequestHeaders().put(Headers.HOST, DefaultServer.getHostAddress());
connection.sendRequest(request, createClientCallback(responses, latch));
}
});
latch.await(10, TimeUnit.SECONDS);
Assert.assertEquals(3, responses.size());
Assert.assertTrue(responses.containsKey("/push-example/index.html"));
Assert.assertEquals(StatusCodes.OK, responses.get("/push-example/index.html").getResponseCode());
Assert.assertNotNull(responses.get("/push-example/index.html").getAttachment(RESPONSE_BODY));
Assert.assertTrue(responses.containsKey("/push-example/resources/one.js"));
Assert.assertEquals(StatusCodes.OK, responses.get("/push-example/resources/one.js").getResponseCode());
Assert.assertNotNull(responses.get("/push-example/resources/one.js").getAttachment(RESPONSE_BODY));
Assert.assertTrue(responses.containsKey("/push-example/resources/one.css"));
Assert.assertEquals(StatusCodes.OK, responses.get("/push-example/resources/one.css").getResponseCode());
Assert.assertNotNull(responses.get("/push-example/resources/one.css").getAttachment(RESPONSE_BODY));
} finally {
IoUtils.safeClose(connection);
}
}
}
Loading

0 comments on commit 43244cc

Please sign in to comment.