diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index 17551465f05..ee163ef190c 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -17,6 +17,7 @@ android { srcDirs += "${projectDir}/../interop-testing/src/main/java/" setIncludes(["io/grpc/android/integrationtest/**", "io/grpc/testing/integration/AbstractInteropTest.java", + "io/grpc/testing/integration/TestCases.java", "io/grpc/testing/integration/TestServiceImpl.java", "io/grpc/testing/integration/Util.java"]) } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java index 2d16065254a..f058a9190cf 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java @@ -59,7 +59,8 @@ public enum TestCases { RPC_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on the same channel"), CHANNEL_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on a new channel"), ORCA_PER_RPC("report backend metrics per query"), - ORCA_OOB("report backend metrics out-of-band"); + ORCA_OOB("report backend metrics out-of-band"), + MCS_CS("max concurrent streaming connection scaling"); private final String description; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index 125d876b705..7dbe8483d91 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -70,6 +70,7 @@ import io.grpc.testing.integration.Messages.TestOrcaReport; import java.io.File; import java.io.FileInputStream; +import java.io.IOException; import java.io.InputStream; import java.nio.charset.Charset; import java.util.Arrays; @@ -563,7 +564,46 @@ private void runTest(TestCases testCase) throws Exception { tester.testOrcaOob(); break; } - + + case MCS_CS: { + ChannelCredentials channelCredentials; + if (useTls) { + if (!useTestCa) { + channelCredentials = TlsChannelCredentials.create(); + } else { + try { + channelCredentials = TlsChannelCredentials.newBuilder() + .trustManager(TlsTesting.loadCert("ca.pem")) + .build(); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + } else { + channelCredentials = InsecureChannelCredentials.create(); + } + ManagedChannelBuilder channelBuilder; + if (serverPort == 0) { + channelBuilder = Grpc.newChannelBuilder(serverHost, channelCredentials); + } else { + channelBuilder = + Grpc.newChannelBuilderForAddress(serverHost, serverPort, channelCredentials); + } + if (serverHostOverride != null) { + channelBuilder.overrideAuthority(serverHostOverride); + } + channelBuilder.disableServiceConfigLookUp(); + try { + @SuppressWarnings("unchecked") + Map serviceConfigMap = (Map) JsonParser.parse( + "{\"connection_scaling\":{\"max_connections_per_subchannel\": 2}}"); + channelBuilder.defaultServiceConfig(serviceConfigMap); + } catch (IOException e) { + throw new RuntimeException(e); + } + tester.testMcs(TestServiceGrpc.newStub(channelBuilder.build())); + break; + } default: throw new IllegalArgumentException("Unknown test case: " + testCase); } @@ -596,6 +636,7 @@ private ClientInterceptor maybeCreateAdditionalMetadataInterceptor( } private class Tester extends AbstractInteropTest { + @Override protected ManagedChannelBuilder createChannelBuilder() { boolean useGeneric = false; @@ -979,31 +1020,16 @@ public void testOrcaOob() throws Exception { .build(); final int retryLimit = 5; - BlockingQueue queue = new LinkedBlockingQueue<>(); - final Object lastItem = new Object(); + StreamingOutputCallResponseObserver streamingOutputCallResponseObserver = + new StreamingOutputCallResponseObserver(); StreamObserver streamObserver = - asyncStub.fullDuplexCall(new StreamObserver() { - - @Override - public void onNext(StreamingOutputCallResponse value) { - queue.add(value); - } - - @Override - public void onError(Throwable t) { - queue.add(t); - } - - @Override - public void onCompleted() { - queue.add(lastItem); - } - }); + asyncStub.fullDuplexCall(streamingOutputCallResponseObserver); streamObserver.onNext(StreamingOutputCallRequest.newBuilder() .setOrcaOobReport(answer) .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); - assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + assertThat(streamingOutputCallResponseObserver.take()) + .isInstanceOf(StreamingOutputCallResponse.class); int i = 0; for (; i < retryLimit; i++) { Thread.sleep(1000); @@ -1016,7 +1042,8 @@ public void onCompleted() { streamObserver.onNext(StreamingOutputCallRequest.newBuilder() .setOrcaOobReport(answer2) .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); - assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + assertThat(streamingOutputCallResponseObserver.take()) + .isInstanceOf(StreamingOutputCallResponse.class); for (i = 0; i < retryLimit; i++) { Thread.sleep(1000); @@ -1027,7 +1054,7 @@ public void onCompleted() { } assertThat(i).isLessThan(retryLimit); streamObserver.onCompleted(); - assertThat(queue.take()).isSameInstanceAs(lastItem); + assertThat(streamingOutputCallResponseObserver.verifiedCompleted()).isTrue(); } @Override @@ -1054,6 +1081,84 @@ protected ServerBuilder getHandshakerServerBuilder() { protected int operationTimeoutMillis() { return 15000; } + + class StreamingOutputCallResponseObserver implements + StreamObserver { + private final Object lastItem = new Object(); + private final BlockingQueue queue = new LinkedBlockingQueue<>(); + + @Override + public void onNext(StreamingOutputCallResponse value) { + queue.add(value); + } + + @Override + public void onError(Throwable t) { + queue.add(t); + } + + @Override + public void onCompleted() { + queue.add(lastItem); + } + + Object take() throws InterruptedException { + return queue.take(); + } + + boolean verifiedCompleted() throws InterruptedException { + return queue.take() == lastItem; + } + } + + public void testMcs(TestServiceGrpc.TestServiceStub asyncStub) throws Exception { + StreamingOutputCallResponseObserver responseObserver1 = + new StreamingOutputCallResponseObserver(); + StreamObserver streamObserver1 = + asyncStub.fullDuplexCall(responseObserver1); + StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder() + .setSendClientSocketAddressInResponse( + Messages.BoolValue.newBuilder().setValue(true).build()) + .build()) + .build(); + streamObserver1.onNext(request); + Object responseObj = responseObserver1.take(); + StreamingOutputCallResponse callResponse = (StreamingOutputCallResponse) responseObj; + String clientSocketAddressInCall1 = callResponse.getClientSocketAddress(); + assertThat(clientSocketAddressInCall1).isNotEmpty(); + + StreamingOutputCallResponseObserver responseObserver2 = + new StreamingOutputCallResponseObserver(); + StreamObserver streamObserver2 = + asyncStub.fullDuplexCall(responseObserver2); + streamObserver2.onNext(request); + callResponse = (StreamingOutputCallResponse) responseObserver2.take(); + String clientSocketAddressInCall2 = callResponse.getClientSocketAddress(); + + assertThat(clientSocketAddressInCall1).isEqualTo(clientSocketAddressInCall2); + + // The first connection is at max rpc call count of 2, so the 3rd rpc will cause a new + // connection to be created in the same subchannel and not get queued. + StreamingOutputCallResponseObserver responseObserver3 = + new StreamingOutputCallResponseObserver(); + StreamObserver streamObserver3 = + asyncStub.fullDuplexCall(responseObserver3); + streamObserver3.onNext(request); + callResponse = (StreamingOutputCallResponse) responseObserver3.take(); + String clientSocketAddressInCall3 = callResponse.getClientSocketAddress(); + + // This assertion is currently failing because connection scaling when MCS limit has been + // reached is not yet implemented in gRPC Java. + assertThat(clientSocketAddressInCall3).isNotEqualTo(clientSocketAddressInCall1); + + streamObserver1.onCompleted(); + assertThat(responseObserver1.verifiedCompleted()).isTrue(); + streamObserver2.onCompleted(); + assertThat(responseObserver2.verifiedCompleted()).isTrue(); + streamObserver3.onCompleted(); + assertThat(responseObserver3.verifiedCompleted()).isTrue(); + } } private static String validTestCasesHelpText() { diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java index a9ee9382495..27cfbb9fdb2 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java @@ -16,13 +16,18 @@ package io.grpc.testing.integration; +import static io.grpc.Grpc.TRANSPORT_ATTR_REMOTE_ADDR; + import com.google.common.base.Preconditions; import com.google.common.collect.Queues; import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.ByteString; +import io.grpc.Context; +import io.grpc.Contexts; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; @@ -42,10 +47,12 @@ import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; import io.grpc.testing.integration.Messages.TestOrcaReport; import io.grpc.testing.integration.TestServiceGrpc.AsyncService; +import java.net.SocketAddress; import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Queue; @@ -61,8 +68,8 @@ * sent in response streams. */ public class TestServiceImpl implements io.grpc.BindableService, AsyncService { + static Context.Key PEER_ADDRESS_CONTEXT_KEY = Context.key("peer-address"); private final Random random = new Random(); - private final ScheduledExecutorService executor; private final ByteString compressableBuffer; private final MetricRecorder metricRecorder; @@ -235,9 +242,27 @@ public void onNext(StreamingOutputCallRequest request) { .asRuntimeException()); return; } + if (whetherSendClientSocketAddressInResponse(request)) { + responseObserver.onNext( + StreamingOutputCallResponse.newBuilder() + .setClientSocketAddress(PEER_ADDRESS_CONTEXT_KEY.get().toString()) + .build()); + return; + } dispatcher.enqueue(toChunkQueue(request)); } + private boolean whetherSendClientSocketAddressInResponse(StreamingOutputCallRequest request) { + Iterator responseParametersIterator = + request.getResponseParametersList().iterator(); + while (responseParametersIterator.hasNext()) { + if (responseParametersIterator.next().getSendClientSocketAddressInResponse().getValue()) { + return true; + } + } + return false; + } + @Override public void onCompleted() { if (oobTestLocked) { @@ -507,7 +532,8 @@ public static List interceptors() { return Arrays.asList( echoRequestHeadersInterceptor(Util.METADATA_KEY), echoRequestMetadataInHeaders(Util.ECHO_INITIAL_METADATA_KEY), - echoRequestMetadataInTrailers(Util.ECHO_TRAILING_METADATA_KEY)); + echoRequestMetadataInTrailers(Util.ECHO_TRAILING_METADATA_KEY), + new McsScalingTestcaseInterceptor()); } /** @@ -539,6 +565,22 @@ public void close(Status status, Metadata trailers) { }; } + static class McsScalingTestcaseInterceptor implements ServerInterceptor { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + SocketAddress peerAddress = call.getAttributes().get(TRANSPORT_ATTR_REMOTE_ADDR); + + // Create a new context with the peer address value + Context newContext = Context.current().withValue(PEER_ADDRESS_CONTEXT_KEY, peerAddress); + try { + return Contexts.interceptCall(newContext, call, headers, next); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + } + /** * Echoes request headers with the specified key(s) from a client into response headers only. */ diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java index fc4cdf9178f..0227482f3f5 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java @@ -75,6 +75,7 @@ public void run() { private int port = 8080; private boolean useTls = true; private boolean useAlts = false; + private boolean setMcsLimit = false; private ScheduledExecutorService executor; private Server server; @@ -118,6 +119,10 @@ void parseArgs(String[] args) { usage = true; break; } + } else if ("set_max_concurrent_streams_limit".equals(key)) { + setMcsLimit = Boolean.parseBoolean(value); + // TODO: Make Netty server builder usable for IPV6 as well (not limited to MCS handling) + addressType = Util.AddressType.IPV4; // To use NettyServerBuilder } else { System.err.println("Unknown argument: " + key); usage = true; @@ -141,6 +146,8 @@ void parseArgs(String[] args) { + "\n for testing. Only effective when --use_alts=true." + "\n --address_type=IPV4|IPV6|IPV4_IPV6" + "\n What type of addresses to listen on. Default IPV4_IPV6" + + "\n --set_max_concurrent_streams_limit" + + "\n Whether to set the maximum concurrent streams limit" ); System.exit(1); } @@ -186,6 +193,9 @@ void start() throws Exception { if (v4Address != null && !v4Address.equals(localV4Address)) { ((NettyServerBuilder) serverBuilder).addListenAddress(v4Address); } + if (setMcsLimit) { + ((NettyServerBuilder) serverBuilder).maxConcurrentCallsPerConnection(2); + } break; case IPV6: List v6Addresses = Util.getV6Addresses(port); diff --git a/interop-testing/src/main/proto/grpc/testing/messages.proto b/interop-testing/src/main/proto/grpc/testing/messages.proto index fbcb6b4ce9b..74deae15afc 100644 --- a/interop-testing/src/main/proto/grpc/testing/messages.proto +++ b/interop-testing/src/main/proto/grpc/testing/messages.proto @@ -159,6 +159,10 @@ message ResponseParameters { // implement the full compression tests by introspecting the call to verify // the response's compression status. BoolValue compressed = 3; + + // Whether to request the server to send the requesting client's socket + // address in the response. + BoolValue send_client_socket_address_in_response = 4; } // Server-streaming request. @@ -186,6 +190,9 @@ message StreamingOutputCallRequest { message StreamingOutputCallResponse { // Payload to increase response size. Payload payload = 1; + + // The client's socket address if requested. + string client_socket_address = 2; } // For reconnect interop test only. diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java index ab32d584e7c..4f035749111 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java @@ -67,7 +67,8 @@ public void testCaseNamesShouldMapToEnums() { "cancel_after_first_response", "timeout_on_sleeping_server", "orca_per_rpc", - "orca_oob" + "orca_oob", + "mcs_cs", }; // additional test cases